Compare commits


63 Commits

Author SHA1 Message Date
Wing Lian
5b15816cf4 drop valueerror as this was from when 4bit required gptq 2024-08-22 19:16:32 -04:00
Wing Lian
fefa95e350 most model types now support flash attention 2 regardless of multipack support (#1854) 2024-08-22 16:39:23 -04:00
Wing Lian
b33dc07a77 rename nightly test and add badge (#1853) 2024-08-22 13:13:33 -04:00
Wing Lian
dcbff16983 run nightly ci builds against upstream main (#1851)
* run nightly ci builds against upstream main

* add test badges

* run the multigpu tests against nightly main builds too
2024-08-22 13:10:54 -04:00
Wing Lian
2f8037fee6 ensure that the hftrainer deepspeed config is set before the trainer class is ever init'ed (#1850) [skip ci] 2024-08-22 13:10:40 -04:00
Aman Gupta Karmani
de4ea2d1f2 docs: minor syntax highlight fix (#1839) 2024-08-22 11:47:34 -04:00
JohanWork
7ed92e61c2 fix: prompt phi (#1845) [skip ci]
* correcting phi system prompt

* phi test

* update

* add test
2024-08-22 11:46:57 -04:00
Wing Lian
9caa3eb699 make the train_on_eos default to turn so all eos tokens are treated the same (#1847) [skip ci] 2024-08-22 11:45:37 -04:00
Wing Lian
5b0b774e38 ensure that the bias is also in the correct dtype (#1848) [skip ci]
* ensure that the bias is also in the correct dtype

* add nightly for dpo-qlora-fsdp
2024-08-22 11:45:00 -04:00
Wing Lian
c3fc529bfc numpy 2.1.0 was released, but incompatible with numba (#1849) [skip ci] 2024-08-22 11:44:45 -04:00
Gal Cohen (galco)
957c956f89 rename jamba example (#1846) [skip ci]
* rename jamba example

* feat: change readme

---------

Co-authored-by: Gal Cohen <galc@ai21.com>
2024-08-22 09:22:55 -04:00
Aman Gupta Karmani
f07802f9fa examples: fix tiny-llama pretrain yml syntax (#1840) 2024-08-21 13:37:51 -04:00
Gal Cohen (galco)
9f917245f6 feat: add jamba chat_template (#1843)
* feat: add jamba chat_template

* fix: black

* feat: jamba fsdp+qlora

---------

Co-authored-by: Gal Cohen <galc@ai21.com>
2024-08-21 13:37:17 -04:00
Aman Gupta Karmani
649c19aba3 pretrain: fix with sample_packing=false (#1841) 2024-08-21 13:36:51 -04:00
Gal Cohen (galco)
5aac4bc284 fix: dont change quant storage dtype in case of fsdp (#1837)
* fix: dont change quant storage dtype in case of fsdp

* fix black

---------

Co-authored-by: Gal Cohen <galc@ai21.com>
2024-08-20 12:41:48 -04:00
Wing Lian
e29931259b optionally save the final FSDP model as a sharded state dict (#1828)
* efficiently save very large llms when using FSDP

* fix parsing and index of sharded chunks

* only save fsdp on main process

* debugging for rename

* save sharded state dict

* remove unused new param

* get state dict directly

* tweak acc merge fsdp to shard the weight files

* sharded_state_dict alongside save_safetensors seems to hang on checkpoint save
2024-08-19 14:59:24 -04:00
Wing Lian
b1d2921222 add validation to prevent 8bit lora finetuning on H100s (#1827) 2024-08-16 21:32:00 -04:00
Wing Lian
803fed3e90 update sklearn version, torch compile env vars, don't worry about failure on preprocess load model (#1821)
* update sklearn version, torch compile env vars, don't worry about failure on preprocess load model

* There is already a condition check within the function. This outer one is not necessary

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2024-08-16 10:41:51 -04:00
NanoCode012
68a3c7678a fix: parse model_kwargs (#1825) 2024-08-16 07:51:19 -04:00
NanoCode012
f18925fb4b fix: parse eager_attention (#1824) 2024-08-14 09:46:46 -04:00
Wing Lian
1853d6021d bump hf dependencies (#1823)
* bump hf dependencies

* revert optimum version change

* don't bump tokenizers all the way to 0.20 yet since transformers doesn't support that
2024-08-11 16:27:41 -04:00
Chiwan Park
0801f239cc fix the incorrect max_length for chat template (#1818) 2024-08-09 11:50:31 -04:00
Wing Lian
54392ac8a6 Attempt to run multigpu in PR CI for now to ensure it works (#1815) [skip ci]
* Attempt to run multigpu in PR CI for now to ensure it works

* fix yaml file

* forgot to include multigpu tests

* fix call to cicd.multigpu

* dump dictdefault to dict for yaml conversion

* use to_dict instead of casting

* 16bit-lora w flash attention, 8bit lora seems problematic

* add llama fsdp test

* more tests

* Add test for qlora + fsdp with prequant

* limit accelerate to 2 processes and disable broken qlora+fsdp+bnb test

* move multigpu tests to biweekly
2024-08-09 11:50:13 -04:00
Wing Lian
3e2b269d06 update tinyllama to use final instead of checkpoints (#1820) [skip ci] 2024-08-09 10:58:19 -04:00
Wing Lian
5ee4b7325f fix z3 leaf configuration when not using lists (#1817) [skip ci] 2024-08-09 10:54:52 -04:00
Wing Lian
70978467a0 skip no commit to main on ci (#1814) 2024-08-06 15:25:54 -04:00
Wing Lian
850f999a76 update peft and transformers (#1811) 2024-08-06 10:32:05 -04:00
Wing Lian
c56e0a79a5 logging improvements (#1808) [skip ci]
* logging improvements

* fix sort
2024-08-06 10:31:50 -04:00
Wing Lian
35d5e59d78 set z3 leaf for deepseek v2 (#1809) [skip ci]
* set z3 leaf for deepseek v2

* add deepseek v2 chat template
2024-08-06 09:30:46 -04:00
Wing Lian
fbbeb4fee0 remove un-necessary zero-first guard as it's already only called in a parent fn (#1810) [skip ci] 2024-08-06 09:29:23 -04:00
Wing Lian
ecdda006de One cycle lr (#1803)
* refactor one_cycle lr scheduler so it's reusable in more situations

* fix validation for lr_scheduler

* default to cosine anneal strategy

* one cycle lr expects cos
2024-08-05 13:12:05 -04:00
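A minimal sketch of enabling the refactored scheduler, assuming it keeps its existing `one_cycle` name in the config; per the bullets above, the anneal strategy now defaults to cosine, so no extra arguments should be needed:

```yaml
# sketch only: enable the one-cycle LR scheduler (anneal strategy defaults to cosine)
lr_scheduler: one_cycle
learning_rate: 0.0002
warmup_steps: 10
```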
Ben Feuer
b7665c26c8 Update conversation.qmd (#1788) [skip ci] 2024-08-05 12:44:26 -04:00
Aaditya Ura (looking for PhD Fall’24)
cb023c70db Update instruct-lora-8b.yml (#1789) [skip ci]
The config raises an error when `pad_to_sequence_len` is true unless the end-of-text token is set as the pad token.
2024-08-05 12:43:20 -04:00
ripes
7402eb9dcb Fix setting correct repo id when pushing dataset to hub (#1657)
* use the ds hash as the dataset's config_name

* improve logging for loading/pushing ds to hub

* fix missing f string
2024-08-05 12:42:15 -04:00
Sri Kainkaryam
203816f7b4 Fix colab example notebook (#1805) [skip ci] 2024-08-04 13:24:26 -04:00
Wing Lian
78b42a3fe1 fix roles to train defaults and make logging less verbose (#1801) 2024-07-30 20:58:17 -04:00
Wing Lian
3ebf22464b qlora-fsdp ram efficient loading with hf trainer (#1791)
* fix 405b with lower cpu ram requirements

* make sure to use double quant and only skip output embeddings

* set model attributes

* more fixes for sharded fsdp loading

* update the base model in example to use pre-quantized nf4-bf16 weights

* upstream fixes  for qlora+fsdp
2024-07-30 19:21:38 -04:00
Wing Lian
dbf8fb549e publish axolotl images without extras in the tag name (#1798) 2024-07-30 13:36:19 -04:00
Wing Lian
9a63884597 update test and main/nightly builds (#1797)
* update test and main/nightly builds

* don't install mamba-ssm on 2.4.0 since it has no wheels yet
2024-07-30 12:37:40 -04:00
Wing Lian
c5587b45ac use 12.4.1 instead of 12.4 [skip-ci] (#1796) 2024-07-30 08:50:23 -04:00
Wing Lian
d4f6a6b103 fix dockerfile and base builder (#1795) [skip-ci] 2024-07-30 08:34:37 -04:00
Wing Lian
d8d1788ffc move to supporting mostly 12.1 w 2.3.1 and add new 12.4 with 2.4.0 (#1793) 2024-07-30 08:06:11 -04:00
mhenrichsen
3bc8e64557 Update README.md (#1792) 2024-07-30 07:59:53 +02:00
Adam Brusselback
55cc214c76 Add flexible configuration options for chat_template dataset training (#1756)
* Add flexible configuration options for chat dataset training

- Introduce roles_to_train parameter to set training labels by role
- Add train_on_eos option to configure training on end-of-sequence tokens
- Implement per-message training configuration in dataset
- Allow fine-grained control over training specific portions of messages
- Add message_field_training and message_field_training_detail settings
- Implement mapping between dataset character offsets and tokenized prompt
- Enhance test suite to cover new functionality

* Fix missing field inits, things weren't working from yaml.

* Add flexible configuration options for chat dataset training

- Introduce roles_to_train parameter to set training labels by role
- Add train_on_eos option to configure training on end-of-sequence tokens
- Implement per-message training configuration in dataset
- Allow fine-grained control over training specific portions of messages
- Add message_field_training and message_field_training_detail settings
- Implement mapping between dataset character offsets and tokenized prompt
- Enhance test suite to cover new functionality

* Fix missing field inits, things weren't working from yaml.

* chore: lint

* Revert test repo back to NousResearch after opening PR to fix the tokenizer_config.json.

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-07-28 21:48:57 -04:00
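A hedged sketch of the options introduced by this PR; the key names come from the commit bullets above, while the dataset path and field values are illustrative placeholders:

```yaml
datasets:
  - path: my/chat-dataset            # hypothetical dataset path
    type: chat_template
    # only compute loss on these roles (illustrative value)
    roles_to_train: ["assistant"]
    # a later commit in this compare (9caa3eb699) makes "turn" the default
    train_on_eos: turn
    # optional per-message overrides read from fields in the dataset itself
    message_field_training: train                 # illustrative field name
    message_field_training_detail: train_detail   # illustrative field name
```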
Wing Lian
94ba93259f various batch of fixes (#1785)
* various batch of fixes

* more tweaks

* fix autoawq requirement for torch flexibility

* simplify conditionals

* multi-node fixes wip

* bump transformers and include 405b qlora+fsdp yaml
2024-07-28 07:25:54 -04:00
Wing Lian
22680913f3 Bump deepspeed 20240727 (#1790)
* pin deepspeed to 0.14.4 otherwise it doesn't play nice with trl

* Add test to import to try to trigger import dependencies
2024-07-27 10:24:11 -04:00
Wing Lian
6a9cfec222 add support for simpo via cpo trainer (#1772)
* add support for simpo via cpo trainer

* add cpo_alpha / sft_weight from the paper

* make sure to use the right builder for simpo
2024-07-23 21:22:16 -04:00
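A rough sketch of turning SimPO on, assuming it is selected through the usual `rl:` key like the other preference-tuning methods; `cpo_alpha` is the SFT-loss weight named in the commit bullets:

```yaml
# sketch: SimPO runs through the CPO trainer under the hood
rl: simpo            # assumed selector value
cpo_alpha: 1.0       # sft weight from the paper, per the commit above
```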
Wing Lian
fe250ada78 fix fsdp loading of models, esp 70b (#1780) 2024-07-23 19:54:28 -04:00
Wing Lian
e6b299dd79 bump flash attention to 2.6.2 (#1781) [skip ci] 2024-07-23 19:54:15 -04:00
Wing Lian
608a2f3180 bump transformers for updated llama 3.1 (#1778)
* bump transformers for updated llama 3.1

* bump for patch fix
2024-07-23 13:21:03 -04:00
Wing Lian
87455e7f32 swaps to use newer sample packing for mistral (#1773)
* swaps to use newer sample packing for mistral

* fix multipack patch test

* patch the common fa utils

* update for refactor of flash attn unpad

* remove un-needed drop attn mask for mistral

* bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2

* update test
2024-07-23 01:41:11 -04:00
Keith Stevens
985819d89b Add a chat_template prompt strategy for DPO (#1725)
* Implementing a basic chat_template strategy for DPO datasets

This mimics the sft chat_template strategy such that users can:
* Specify the messages field
* Specify the per message role and content fields
* Specify the chosen and rejected fields
* Let the tokenizer construct the raw prompt
* Ensure the chosen and rejected fields don't have any prefix tokens

* Adding additional dpo chat template unittests

* Rename test class
2024-07-21 09:10:42 -04:00
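Condensed from the DPO example config added later in this compare, the dataset stanza the new strategy expects looks roughly like:

```yaml
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
```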
Wing Lian
fa91b698e9 Fix untrained tokens (#1771)
* fix untrained reserved tokens

* save model after fixing untrained embeddings

* don't need fsdp conditional here
2024-07-19 12:21:37 -04:00
Wing Lian
e4063d60a7 bump transformers and set roundup_power2_divisions for more VRAM improvements, low bit ao optimizers (#1769)
* bump transformers and set roundup_power2_divisions for more VRAM improvements

* support for low bit optimizers from torch ao

* fix check for alternate optimizers and use nous models on hf for llama3

* add missing check for ao_adamw_fp8

* fix check when using custom optimizers w adamw
2024-07-19 00:47:07 -04:00
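The low-bit optimizers are exposed through the usual `optimizer:` key; a sketch, with `ao_adamw_fp8` taken from the commit message and any other variants assumed rather than confirmed:

```yaml
# sketch: use a torch ao low-bit optimizer (ao_adamw_fp8 is named in the commit)
optimizer: ao_adamw_fp8
learning_rate: 0.00001
```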
Wing Lian
7830fe04b5 Unsloth rope (#1767)
* Add unsloth rope embeddings support

* support for models weights in 4bit and do some memory gc

* use accelerate logger

* add unsloth llama rms norm optims

* update docs for unsloth

* more docs info
2024-07-18 14:54:41 -04:00
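The flags this commit adds line up with the new docs/unsloth.qmd included later in this compare, which lists the RoPE and RMSNorm optimizations as composable with multi-GPU finetuning:

```yaml
# from docs/unsloth.qmd in this compare
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```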
Wing Lian
c86c32a627 set the number of dataset processes on the DPO Config rather than the trainer (#1762) 2024-07-17 15:38:37 -04:00
Wing Lian
8731b95d04 re-enable PYTORCH_CUDA_ALLOC_CONF expandable_segments (#1765) [skip ci] 2024-07-17 15:38:26 -04:00
Wing Lian
8619b2d855 add torch_compile_mode options (#1763) [skip ci]
* add torch_compile_mode options

* make sure n_gpu is an int
2024-07-17 15:38:07 -04:00
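A hedged sketch of the new option; `torch_compile_mode` presumably accepts the standard `torch.compile` mode strings, which is an assumption rather than something stated in the commit:

```yaml
torch_compile: true
# assumed to take torch.compile mode strings: "default", "reduce-overhead", "max-autotune"
torch_compile_mode: max-autotune
```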
Wing Lian
976f85195a fixes to accelerator so that iterable pretraining datasets work (#1759)
* fixes to accelerator so that iterable pretraining datasets work

* fix the pretraining test params

* split batches, not dispatch batches needs to be set

* update c4 datasets

* set epochs in pretrain config test

* need to set both split_batches and dispatch_batches to false for pretraining

* fix bool val in comment
2024-07-17 10:58:38 -04:00
Wing Lian
152ab76623 fix num gpu check (#1760) 2024-07-17 10:58:14 -04:00
Wing Lian
5f58555bd0 support for llama multipack using updated code/patches (#1754)
* support for llama multipack using updated code/patches

* also support unsloth patches

* incorrect arg

* add config validation for unsloth

* add missing return to validation

* add another missing return to validation
2024-07-16 17:36:29 -04:00
Wing Lian
cfc533a7f7 torch compile and cuda alloc improvements (#1755)
* enable experimental expandable_segments

* hf trainer seems to be missing torch compile

* disable PYTORCH_CUDA_ALLOC_CONF to see if that fixes cicd
2024-07-16 16:00:23 -04:00
Wing Lian
e1725aef2b update modal package and don't cache pip install (#1757)
* update modal package and cleanup pip cache

* more verbosity on the test
2024-07-16 14:45:38 -04:00
74 changed files with 3837 additions and 674 deletions

View File

@@ -12,36 +12,24 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.10"
pytorch: 2.1.2
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -67,6 +55,7 @@ jobs:
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDNN_VERSION=${{ matrix.cudnn_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}

View File

@@ -13,28 +13,22 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -65,6 +59,7 @@ jobs:
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -75,27 +70,22 @@ jobs:
strategy:
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -134,7 +124,7 @@ jobs:
matrix:
include:
- cuda: 121
cuda_version: 12.1.0
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:

.github/workflows/multi-gpu-e2e.yml (new file, 52 lines)
View File

@@ -0,0 +1,52 @@
name: docker-multigpu-tests-biweekly
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -12,28 +12,22 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -75,27 +69,22 @@ jobs:
strategy:
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

.github/workflows/tests-nightly.yml (new file, 116 lines)
View File

@@ -0,0 +1,116 @@
name: Tests Nightly against upstream main
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
env:
SKIP: no-commit-to-branch
pytest:
name: PyTest
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests

View File

@@ -26,6 +26,8 @@ jobs:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
env:
SKIP: no-commit-to-branch
pytest:
name: PyTest
@@ -57,6 +59,10 @@ jobs:
run: |
pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
@@ -68,27 +74,24 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -99,12 +102,13 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal jinja2
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal

View File

@@ -8,6 +8,8 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:

View File

@@ -1,5 +1,9 @@
# Axolotl
![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg)
![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg)
![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg)
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
@@ -22,38 +26,49 @@ Features:
<td>
## Table of Contents
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference-playground)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- [All Config Options](#all-config-options)
- Advanced Topics
- [Multipack](./docs/multipack.qmd)
- [RLHF & DPO](./docs/rlhf.qmd)
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing](#contributing-)
- [Sponsors](#sponsors-)
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Axolotl supports](#axolotl-supports)
- [Quickstart ⚡](#quickstart-)
- [Usage](#usage)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu)
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [LambdaLabs](#lambdalabs)
- [GCP](#gcp)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [All Config Options](#all-config-options)
- [Train](#train)
- [Preprocess dataset](#preprocess-dataset)
- [Multi-GPU](#multi-gpu)
- [DeepSpeed](#deepspeed)
- [FSDP](#fsdp)
- [FSDP + QLoRA](#fsdp--qlora)
- [Weights \& Biases Logging](#weights--biases-logging)
- [Special Tokens](#special-tokens)
- [Inference Playground](#inference-playground)
- [Merge LORA to base](#merge-lora-to-base)
- [Common Errors 🧰](#common-errors-)
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
- [Badge ❤🏷️](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
</td>
<td>
@@ -95,6 +110,7 @@ Features:
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
@@ -333,7 +349,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config

View File

@@ -36,6 +36,7 @@ website:
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"

View File

@@ -8,6 +8,7 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
@@ -23,10 +24,17 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -2,5 +2,5 @@
set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ /workspace/axolotl/tests/e2e/

cicd/multigpu.py (new file, 77 lines)
View File

@@ -0,0 +1,77 @@
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

cicd/multigpu.sh (new executable file, 5 lines)
View File

@@ -0,0 +1,5 @@
#!/bin/bash
set -e
# only run one test at a time so as not to OOM the GPU
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/

View File

@@ -1,6 +1,8 @@
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -21,11 +23,12 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
}
dockerfile_contents = df_template.render(**df_args)

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -3,7 +3,7 @@ ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -54,6 +54,14 @@ conversations where `from` is `prompter` `assistant` instead of default sharegpt
{"conversations": [{"from": "...", "value": "..."}]}
```
## sharegpt.load_ultrachat
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
```{.json filename="data.jsonl"}
{"messages": [{"user": "...", "assistant": "..."}]}
```
## sharegpt_jokes
creates a chat where bot is asked to tell a joke, then explain why the joke is funny

docs/torchao.qmd (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
---
title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---
### Installation
Stable Release from the PyTorch index
```bash
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
```
Nightly release
```bash
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
```

docs/unsloth.qmd (new file, 49 lines)
View File

@@ -0,0 +1,49 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---
### Overview
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
standard industry baselines.
### Installation
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
to date libraries.
```bash
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
```
### Using unsloth w Axolotl
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
Our unsloth integration is currently limited to the following model architectures:
- llama
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
### Limitations
- Single GPU only; e.g. no multi-gpu support
- No deepspeed or FSDP support (requires multi-gpu)
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
- No MoE support.

View File

@@ -43,7 +43,6 @@
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""

View File

@@ -6,5 +6,5 @@
- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
- ✅ qlora single-gpu, ~51GiB VRAM
- ✅ multipack
- FSDP
- FSDP
- ❓ 8-bit LoRA

View File

@@ -0,0 +1,61 @@
base_model: ai21labs/AI21-Jamba-1.5-Large
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
lora_target_linear: false
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -0,0 +1,81 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
base_model: NousResearch/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
@@ -74,3 +74,5 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -0,0 +1,63 @@
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

View File

@@ -72,4 +72,5 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:

View File

@@ -1,4 +1,4 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
base_model: TinyLlama/TinyLlama_v1.1
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

View File

@@ -1,6 +1,5 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
base_model: TinyLlama/TinyLlama_v1.1
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false

View File

@@ -9,9 +9,9 @@ strict: false
max_steps: 200
pretraining_dataset:
path: c4
name: en
type: pretrain
- path: allenai/c4
name: en
type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/model-out

View File

@@ -1,4 +1,4 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
base_model: TinyLlama/TinyLlama_v1.1
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

View File

@@ -1,18 +1,18 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.42.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
peft==0.12.0
transformers==4.44.0
tokenizers>=0.19.1
bitsandbytes==0.43.3
accelerate==0.33.0
datasets==2.20.0
deepspeed==0.14.4
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
datasets==2.19.1
flash-attn==2.6.1
flash-attn==2.6.3
sentencepiece
wandb
einops
@@ -21,23 +21,24 @@ optimum==1.16.2
hf_transfer
colorama
numba
numpy>=1.24.4
numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy
scikit-learn==1.2.2
scikit-learn==1.4.2
pynvml
art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
mamba-ssm==1.2.0.post1
# remote filesystems
s3fs
gcsfs
s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.9.6

View File

@@ -80,13 +80,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.6.1",
"flash-attn==2.6.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
"deepspeed==0.14.4",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -108,7 +108,6 @@ setup(
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"q-galore-torch==1.0",
"torch-optimi==0.2.1",
],
},

View File

@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -375,13 +375,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)

View File

@@ -0,0 +1,204 @@
"""
This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
"""
import json
import logging
import os
import shutil
from pathlib import Path
from typing import Dict, Union
import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
import transformers
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""
A custom planner to cast tensors to bfloat16 on the fly during loading.
"""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
):
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
"""
state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
no_dist=True,
)
# To handle if state is a dict like {model: {...}}
if len(state_dict.keys()) == 1:
state_dict = state_dict[list(state_dict)[0]]
# Ensure all tensors are in bfloat16
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
# Save index if sharded
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the model
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
if index is not None:
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
fout.write(content)
return save_path_
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
Note: this is a CPU-bound process.
Args:
checkpoint_dir (`str`):
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
# Verify that the checkpoint directory exists
if not checkpoint_dir_.exists():
model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
if model_path_exists and optimizer_path_exists:
err += (
" However, potential model and optimizer checkpoint directories exist."
)
err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
err += "instead."
elif model_path_exists:
err += " However, a potential model checkpoint directory exists."
err += (
f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
)
elif optimizer_path_exists:
err += " However, a potential optimizer checkpoint directory exists."
err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
raise ValueError(err)
# To setup `save` to work
state = PartialState()
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path, safe_serialization
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,
**kwargs,
)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
safe_serialization=True,
)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -2,6 +2,7 @@
CLI to run training on a model
"""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -76,8 +77,19 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
warnings.simplefilter("ignore")
with init_empty_weights(include_buffers=True):
# fmt: off
try:
AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
pass
# fmt: on
LOG.info(
Fore.GREEN

View File

@@ -0,0 +1,15 @@
"""
Common architecture specific constants
"""
MOE_ARCH_BLOCK = {
"dbrx": "DbrxFFN",
"jamba": "JambaSparseMoeBlock",
"jetmoe": [
"JetMoeMoA",
"JetMoeMoE",
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
}

View File

@@ -0,0 +1,150 @@
"""
helper functions for fixing the embeddings/tokenizer
"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import numpy as np
import torch
@torch.inference_mode
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
"""
Many of the newer models have reserved tokens that are not trained.
"""
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
# Get untrained tokens
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained
# Get set and actual tokens
where_untrained = where_untrained.tolist()
if len(where_untrained) == 0:
return False
# Build a set of untrained indices for fast membership checks
where_untrained_set = frozenset(where_untrained)
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
# Remove None items in actual_bad_tokens
actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
# Check if tokenizer and training datasets have bad tokens
if_bad_first = False
if_bad_second = False
# Check tokenizer's chat template for any untrained tokens
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is not None:
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
# Check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
if_bad = any(item in where_untrained_set for item in input_ids)
if if_bad:
if_bad_second = True
break
# Check last 250
if not if_bad_second:
left = max(size_dataset - 250, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
if_bad = any(item in where_untrained_set for item in input_ids)
if if_bad:
if_bad_second = True
break
# Check if bad tokens exists!
if not if_bad_first and not if_bad_second:
return False
# Count all the possible bad tokens
final_counts = np.zeros(
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
)
def mapping(examples):
input_ids = examples["input_ids"]
counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
np.add.at(final_counts, counter, 1)
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
# Remove bad tokens
sum_embedding -= torch.sum(
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
)
sum_lm_head -= torch.sum(
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
)
# Find correct average by dividing by sum of trained tokens
mean_embedding = sum_embedding / n_trained
mean_lm_head = sum_lm_head / n_trained
# Scale each by its relative frequency (count / max count); tokens never seen stay 0
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
mean_embedding = (
mean_embedding.repeat(
(
n_untrained,
1,
)
)
* scaling
)
mean_lm_head = (
mean_lm_head.repeat(
(
n_untrained,
1,
)
)
* scaling
)
where_null = scaling.ravel() == 0
mean_embedding[where_null] = 0
mean_lm_head[where_null] = 0
# Set them to the mean
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return True
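A self-contained toy illustration of the replacement rule implemented above (not the library API): untrained rows are set to the mean of the trained rows, scaled by each token's relative frequency in the dataset, and left at zero when the token never appears.

import torch

trained = torch.tensor([[1.0, 3.0], [3.0, 5.0]])   # two trained embedding rows
mean_row = trained.mean(dim=0)                     # tensor([2., 4.])
counts = torch.tensor([4.0, 0.0])                  # dataset counts for two untrained tokens
scaling = (counts / counts.max().clamp(min=1)).unsqueeze(1)
new_rows = mean_row.repeat(2, 1) * scaling
print(new_rows)  # tensor([[2., 4.], [0., 0.]])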

View File

@@ -8,6 +8,7 @@ import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from collections import defaultdict
@@ -28,9 +29,18 @@ from transformers import (
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl import (
CPOConfig,
CPOTrainer,
DPOConfig,
DPOTrainer,
KTOConfig,
KTOTrainer,
ORPOConfig,
ORPOTrainer,
)
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
@@ -232,6 +242,12 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
@dataclass
@@ -265,7 +281,105 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
"""
class AxolotlTrainer(Trainer):
@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
"""
CPO config for CPO training
"""
simpo_gamma: Optional[float] = field(
default=None,
metadata={"help": "simpo gamma parameter"},
)
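For orientation, a dict sketch of the config keys that route a run through this CPO/SimPO path (normally set in the axolotl YAML config; the values are illustrative, not recommendations):

simpo_cfg = {
    "rl": "simpo",          # selects AxolotlCPOConfig / AxolotlCPOTrainer below
    "simpo_gamma": 0.5,     # forwarded via the field added above
    "cpo_alpha": 1.0,       # optional; forwarded when set
    "rl_beta": 0.1,         # mapped onto the beta parameter
    "sequence_len": 2048,   # forwarded as max_length
}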
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: AxolotlTrainingArguments
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
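A standalone sketch of what the one_cycle branch above builds, with assumed hyperparameters: the warmup fraction becomes pct_start and cosine annealing is used unless overridden through lr_scheduler_kwargs.

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
num_training_steps, num_warmup_steps = 1000, 100
scheduler = OneCycleLR(
    optimizer,
    max_lr=2e-4,
    total_steps=num_training_steps,
    pct_start=num_warmup_steps / num_training_steps,
    anneal_strategy="cos",
)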
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
@@ -290,11 +404,23 @@ class AxolotlTrainer(Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "q_galore_adamw8bit"]
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
):
return super().create_optimizer()
@@ -345,11 +471,23 @@ class AxolotlTrainer(Trainer):
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "q_galore_adamw8bit":
from q_galore_torch import QGaLoreAdamW8bit
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
QGaLoreAdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
if is_sagemaker_mp_enabled():
@@ -359,68 +497,6 @@ class AxolotlTrainer(Trainer):
return self.optimizer
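A hedged standalone sketch of the torchao low-bit optimizers wired in above (selected via optimizer: ao_adamw_8bit and friends); the import path mirrors the one used in create_optimizer and, being a torchao prototype, may change between releases. A CUDA device is assumed.

import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(16, 16, dtype=torch.bfloat16, device="cuda")
optimizer = AdamW8bit(model.parameters(), lr=1e-4)
loss = model(torch.randn(4, 16, dtype=torch.bfloat16, device="cuda")).sum()
loss.backward()
optimizer.step()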
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
@@ -785,6 +861,14 @@ class AxolotlTrainer(Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since the trainer can be flaky
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -814,37 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "onecycle"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass for ReLoRA training
@@ -884,7 +937,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer):
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
@@ -945,7 +998,7 @@ class AxolotlDPOTrainer(DPOTrainer):
return res
class AxolotlORPOTrainer(ORPOTrainer):
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
@@ -953,7 +1006,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(KTOTrainer):
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -961,6 +1014,14 @@ class AxolotlKTOTrainer(KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1120,10 +1181,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
@@ -1173,7 +1230,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
training_arguments_kwargs["fsdp_config"] = {
k.removeprefix("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()  # removeprefix, not lstrip, which strips a character set rather than the prefix
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
@@ -1282,6 +1341,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
# DDP Config
if self.cfg.ddp_timeout:
@@ -1367,12 +1430,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
] = self.cfg.lr_scheduler
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
@@ -1443,7 +1509,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer in ["optimi_adamw", "q_galore_adamw8bit"]:
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
@@ -1476,6 +1547,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
] = self.cfg.accelerator_config
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -1669,16 +1745,27 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl reuses the beta parameter to carry ORPO's alpha value
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
@@ -1686,7 +1773,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
@@ -1732,7 +1818,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
@@ -1740,14 +1825,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
@@ -1760,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):

View File

@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)
def patch_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
def patch_llama_rms_norm():
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
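Hedged usage sketch: after this refactor the two helpers can be called individually, so a caller can opt into just the fused loss or just the fused RMSNorm. The import path is an assumption for illustration.

from axolotl.monkeypatch.llama_attn_hijack_flash import (  # assumed module path
    patch_llama_cross_entropy,
    patch_llama_rms_norm,
)

patch_llama_cross_entropy()  # fused CrossEntropyLoss from flash-attn
patch_llama_rms_norm()       # fused RMSNorm; logs a warning if the kernel is missing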
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled
if cross_entropy:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
patch_llama_cross_entropy()
# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
patch_llama_rms_norm()
class FusedAttention(LlamaAttention):

View File

@@ -2,6 +2,7 @@
# pylint: disable=duplicate-code
import logging
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
)
def patch_mistral_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,

View File

@@ -10,11 +10,14 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"phi3",
"gemma",
"gemma2",
"gemmoe",
@@ -23,13 +26,36 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None):
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return
# retain for legacy
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -58,12 +84,6 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
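A hedged usage sketch: with a recent transformers release the generic patch point above covers the supported architectures, while older releases fall back to the per-model branches; the import path is assumed.

from axolotl.monkeypatch.multipack import patch_for_multipack  # assumed path

patch_for_multipack("llama")
patch_for_multipack("mixtral")  # also swaps the MoE forward when DeepSpeed ZeRO-3 is enabled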
def patch_remote(model_name, config_name, modeling_name):

View File

@@ -1,18 +1,20 @@
"""module for patching with unsloth optimizations"""
import inspect
import logging
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
@@ -97,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch():
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
if model_type == "llama":
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth fast_cross_entropy_loss")
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
@@ -179,12 +184,30 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth attn lora")
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def integrate_rope_embeddings():
import transformers.models.llama.modeling_llama
from unsloth.kernels.rope_embedding import fast_rope_embedding
def apply_rotary_pos_emb( # pylint: disable=unused-argument
q, # pylint: disable=invalid-name
k, # pylint: disable=invalid-name
cos,
sin,
position_ids=None,
unsqueeze_dim=1,
):
return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
@@ -217,7 +240,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -243,9 +266,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
logging.warning(
"unable to apply unsloth lora qkv patch to layer %d", idx
)
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -264,6 +285,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
logging.warning(
LOG.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
def patch_unsloth_layernorm():
try:
import transformers.models.llama.modeling_llama
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
class LlamaRMSNorm(nn.Module):
"""LlamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return Fast_RMS_Layernorm.apply(
hidden_states, self.weight, self.variance_epsilon, False
)
LOG.info("patching with unsloth.kernels.rms_layernorm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning("missing unsloth library")
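The patches above are opt-in; a dict sketch of the config flags that appear alongside them elsewhere in this change set (the flag-to-patch mapping in the comments is inferred from the names, and the values are illustrative):

unsloth_opts = {
    "unsloth_cross_entropy_loss": True,  # integrate_cross_entropy_loss_patch
    "unsloth_rope": True,                # integrate_rope_embeddings
    "unsloth_rms_norm": True,            # patch_unsloth_layernorm
    "unsloth_lora_mlp": True,            # LoRA kernel patches; single-GPU QLoRA only
    "unsloth_lora_qkv": True,
    "unsloth_lora_o": True,
}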

View File

@@ -6,14 +6,16 @@ import logging
from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates
# Configure the logger
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
"""Prompter for HF chat templates"""
def __init__(
self,
@@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter):
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
message_field_training: str = "train",
message_field_training_detail: str = "train_detail",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter):
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
@@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter):
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]
@@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter):
chat_template=self.chat_template,
)
def get_offsets_for_train_detail(
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
) -> List[int]:
tokenized_output = self.tokenizer(
text, return_offsets_mapping=True, add_special_tokens=False
)
tokens = tokenized_output.tokens()
token_offsets = tokenized_output["offset_mapping"]
LOG.debug(f"Tokenizing text: {text}")
LOG.debug(f"Tokens: {tokens}")
# Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
for i in range(len(token_offsets) - 1):
token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
# Ensure the last token's end offset is set correctly
token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
LOG.debug(f"Token offsets: {token_offsets}")
# Initialize all offsets as IGNORE_TOKEN_ID (not trained)
result = [IGNORE_TOKEN_ID] * len(token_offsets)
# Adjust train_details to align with token boundaries
adjusted_train_details = self.adjust_train_details(train_details, token_offsets)
for idx, (start, end) in enumerate(token_offsets):
for detail in adjusted_train_details:
# Check if the token is completely within the detail's range
if start >= detail["begin_offset"] and end <= detail["end_offset"]:
if detail["train"] or not mask_untrainable:
result[idx] = start
LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
else:
LOG.debug(
f"Token {idx} ({tokens[idx]}) marked as non-trainable"
)
elif start < detail["end_offset"] and end > detail["begin_offset"]:
# Token partially overlaps with detail, always mark as non-trainable
LOG.debug(
f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
)
LOG.debug(f"Final result: {result}")
return result
def adjust_train_details(
self, train_details: List[Dict], token_offsets: List[tuple]
) -> List[Dict]:
adjusted_details = []
for detail in train_details:
begin_offset = detail["begin_offset"]
end_offset = detail["end_offset"]
# Find the first token that starts after or at the begin_offset
begin_token = next(
(
i
for i, (t_start, t_end) in enumerate(token_offsets)
if t_start >= begin_offset
),
len(token_offsets),
)
if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
begin_token -= 1
# Find the last token that ends before or at the end_offset
end_token = next(
(
i
for i in range(len(token_offsets) - 1, -1, -1)
if token_offsets[i][1] <= end_offset
),
-1,
)
if (
end_token < len(token_offsets) - 1
and token_offsets[end_token + 1][0] < end_offset
):
end_token += 1
if begin_token <= end_token:
adjusted_begin = token_offsets[begin_token][0]
adjusted_end = token_offsets[end_token][1]
if adjusted_begin != begin_offset or adjusted_end != end_offset:
LOG.warning(
f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
)
adjusted_details.append(
{
"begin_offset": adjusted_begin,
"end_offset": adjusted_end,
"train": detail["train"],
}
)
else:
LOG.warning(
f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
)
return adjusted_details
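An illustrative message shape for the offset logic above, using the prompter's default field names; the offsets are character positions within the turn's content and the text is made up:

turn = {
    "from": "assistant",
    "value": "Sure. Step one is to preheat the oven.",
    "train": None,  # defer to the per-span details below
    "train_detail": [
        {"begin_offset": 0, "end_offset": 4, "train": False},   # "Sure." is masked
        {"begin_offset": 6, "end_offset": 38, "train": True},   # the rest is trained
    ],
}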
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
@@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
_messages = "conversations"
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos="last",
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
@property
def messages(self):
return self._messages
@@ -79,62 +201,170 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._messages = messages
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
last_eos_idx = -1
for index, turn in enumerate(turns):
role = turn.get(self.prompter.message_field_role)
content = turn.get(self.prompter.message_field_content)
train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
tokenized_prompt = {
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)
should_train = (
train_turn
if train_turn is not None
else bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
LOG.debug(f"Should train: {should_train}")
turn_start_idx, turn_end_idx = self.find_turn(
conversation_ids=input_ids, turn=index, turn_content=turn
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail(
content, train_detail
)
LOG.debug(f"Token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
input_ids
):
labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle EOS token
eos_idx = self.find_eos_token(input_ids, turn_end_idx)
if eos_idx == turn_end_idx:
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
):
labels[eos_idx] = input_ids[eos_idx]
LOG.debug(f"EOS token set for training at index {eos_idx}")
else:
LOG.debug(
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle 'last' option for train_on_eos
if self.train_on_eos == "last" and last_eos_idx != -1:
labels[last_eos_idx] = input_ids[last_eos_idx]
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
LOG.debug(f"Final labels: {labels}")
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
def find_eos_token(self, input_ids, start_idx):
eos_token_id = self.tokenizer.eos_token_id
for i in range(start_idx, len(input_ids)):
if input_ids[i] == eos_token_id:
return i
return -1
def find_turn(self, conversation_ids, turn, turn_content):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (str): String containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get(self.prompter.message_field_content, "")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
else:
end_idx = -1
return start_idx, end_idx
def get_conversation_thread(self, prompt):
return prompt[self.messages]
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
message_field_role = (
ds_cfg["message_field_role"]
if ds_cfg and "message_field_role" in ds_cfg
else "from"
)
message_field_content = (
ds_cfg["message_field_content"]
if ds_cfg and "message_field_content" in ds_cfg
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
drop_system_message = (
ds_cfg["drop_system_message"]
if ds_cfg and "drop_system_message" in ds_cfg
else False
)
ds_cfg = ds_cfg or {}
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
"max_length": cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_templates(chat_template),
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
drop_system_message=drop_system_message,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
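For reference, a dict sketch of the dataset-level keys this loader reads, with its defaults spelled out (in practice these live under a datasets entry in the YAML config; the values shown are the defaults or illustrative):

ds_cfg = {
    "chat_template": "chatml",
    "field_messages": "conversations",
    "message_field_role": "from",
    "message_field_content": "value",
    "message_field_training": "training",
    "message_field_training_detail": "train_detail",
    "roles_to_train": ["gpt", "assistant"],
    "train_on_eos": "turn",
    "drop_system_message": False,
}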

View File

@@ -0,0 +1,78 @@
"""
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import chat_templates
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
field_message_role = ds_cfg.get("message_field_role", "role")
field_message_content = ds_cfg.get("message_field_content", "content")
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target
def transform_fn(sample, tokenizer=None):
messages = sample[field_messages]
messages = [
{
"role": role_map[m[field_message_role]],
"content": m[field_message_content],
}
for m in messages
]
chosen = {
"role": role_map[sample[field_chosen][field_message_role]],
"content": sample[field_chosen][field_message_content],
}
rejected = {
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result
return transform_fn
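An illustrative sample for the transform above, using the default field names; the strings are made up and the tokenizer is whatever the DPO pipeline passes in:

sample = {
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    "chosen": {"role": "assistant", "content": "4"},
    "rejected": {"role": "assistant", "content": "5"},
}
# transform_fn(sample, tokenizer=tok) returns a dict with "prompt" rendered via
# the chat template with a generation prompt, and "chosen"/"rejected" trimmed
# so they start at the answer text.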

View File

@@ -65,8 +65,10 @@ class AlpacaPrompter(Prompter):
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
self.system_format = "<|system|>{system}\n"
self.turn_no_input_format = (
"<|user|>\n{instruction}<|end|>\n<|assistant|>\n"
)
self.system_format = "<|system|>\n{system}<|end|>\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input

View File

@@ -12,6 +12,7 @@ import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
@@ -19,6 +20,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -52,6 +54,15 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# enable expandable segments for cuda allocation to improve VRAM usage
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
# load the tokenizer first
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -114,6 +125,13 @@ def train(
total_num_steps,
)
if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
@@ -177,9 +195,12 @@ def train(
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
if cfg.fsdp_final_state_dict_type:
state_dict_type = cfg.fsdp_final_state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
@@ -191,30 +212,38 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
if (
state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
):
save_fsdp_model(
trainer.accelerator.state.fsdp_plugin,
trainer.accelerator,
trainer.model,
cfg.output_dir,
)
elif state_dict_type == "FULL_STATE_DICT":
trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
trainer.save_model(cfg.output_dir)
# the trainer saved a model.safetensors file in the output directory,
# but it is a proxy model and should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
# but it is most likely a proxy model and if so, should be deleted
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
maybe_sharded = os.path.exists(
os.path.join(cfg.output_dir, "model.safetensors.index.json")
)
if maybe_proxy and maybe_sharded:
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted")
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
try:
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,7 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
@@ -77,6 +78,7 @@ class PretrainingDataset(BaseModel):
split: Optional[str] = "train"
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
class UserDefinedPrompterType(BaseModel):
@@ -114,10 +116,16 @@ class SFTDataset(BaseModel):
field_messages: Optional[str] = None
message_field_role: Optional[str] = None
message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -158,6 +166,7 @@ class KTODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
class RLType(str, Enum):
@@ -167,6 +176,7 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -179,6 +189,8 @@ class ChatTemplate(str, Enum):
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -225,6 +237,12 @@ class LoraConfig(BaseModel):
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -304,6 +322,8 @@ class ModelInputConfig(BaseModel):
)
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
@@ -343,7 +363,13 @@ class HyperparametersConfig(BaseModel):
optimizer: Optional[
Union[
OptimizerNames,
Literal["lion_pytorch", "optimi_adamw", "q_galore_adamw8bit"],
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
@@ -356,7 +382,7 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = "cosine"
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
@@ -507,6 +533,8 @@ class AxolotlInputConfig(
dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None
accelerator_config: Optional[Dict[str, Any]] = None
remove_unused_columns: Optional[bool] = None
push_dataset_to_hub: Optional[str] = None
@@ -589,14 +617,21 @@ class AxolotlInputConfig(
flash_attn_fuse_mlp: Optional[bool] = None
flash_optimum: Optional[bool] = None
eager_attention: Optional[bool] = None
unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None
unsloth_lora_o: Optional[bool] = None
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
fsdp_final_state_dict_type: Optional[
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
] = None
val_set_size: Optional[float] = Field(default=0.0)
@@ -605,6 +640,9 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
@@ -626,6 +664,8 @@ class AxolotlInputConfig(
orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None
simpo_gamma: Optional[float] = None
cpo_alpha: Optional[float] = None
kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None
@@ -640,6 +680,8 @@ class AxolotlInputConfig(
chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
@@ -705,6 +747,24 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
# alternatively set ACCELERATE_SPLIT_BATCHES=False
if data.get("pretraining_dataset"):
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
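Sketch of this validator's effect on a streaming pretraining config that leaves accelerator_config unset (the dataset name is illustrative):

data = {"pretraining_dataset": "allenai/c4"}
# after check_pretraining_split_batches_accelerate:
# data["accelerator_config"] == {"split_batches": False, "dispatch_batches": False}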
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
@@ -823,7 +883,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
not self.optimizer or "adamw" not in self.optimizer.value
not self.optimizer or "adamw" not in str(self.optimizer).lower()
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@@ -894,6 +954,8 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if (
data.get("sample_packing")
and data.get("eval_table_size")
@@ -1090,6 +1152,20 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
if (
data.get("fsdp")
and data.get("save_safetensors")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
):
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
return data
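A hypothetical config combination that this new validator rejects (the keys are standard axolotl YAML options, the values are illustrative only):

bad_cfg = {
    "fsdp": ["full_shard", "auto_wrap"],
    "save_safetensors": True,
    "fsdp_config": {"fsdp_state_dict_type": "SHARDED_STATE_DICT"},
}
# Validating bad_cfg would raise:
#   ValueError: FSDP SHARDED_STATE_DICT not compatible with save_safetensors
# Either drop save_safetensors or switch fsdp_state_dict_type to FULL_STATE_DICT.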
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
@@ -1115,6 +1191,55 @@ class AxolotlInputConfig(
raise ValueError("either datasets or pretraining_dataset is required")
return data
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get(
"unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data
@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1160,9 +1285,37 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_hopper_8bit_lora(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_deepspeed(cls, data):
if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data

View File

@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
) -> Dict[str, List]:
res = tokenizer(
examples,
examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
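With the new signature the function maps directly over a batched Hugging Face dataset. A minimal sketch, assuming encode_pretraining lives in axolotl.utils.data.pretraining (alongside wrap_pretraining_dataset imported later in this diff) and that the text column is named "text":

from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.utils.data.pretraining import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-68m")
ds = Dataset.from_dict({"text": ["hello world", "axolotl pretraining sample"]})
encode = partial(encode_pretraining, tokenizer, 1024)  # max_tokens=1024
ds = ds.map(encode, batched=True, remove_columns=ds.column_names)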

View File

@@ -1,4 +1,5 @@
"""data handling specific to DPO"""
import inspect
import logging
from functools import partial

View File

@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -160,8 +160,12 @@ def load_tokenized_prepared_datasets(
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
LOG.info(
f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
cfg.push_dataset_to_hub,
ds_hash,
token=use_auth_token,
)
dataset = dataset[split]
@@ -170,6 +174,7 @@ def load_tokenized_prepared_datasets(
# pylint: disable=duplicate-code
if dataset:
# This is for the case where we already loaded a pretokenized dataset from the hub
...
elif (
cfg.dataset_prepared_path
@@ -180,7 +185,14 @@ def load_tokenized_prepared_datasets(
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...")
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
if cfg.push_dataset_to_hub:
LOG.info("Unable to find prepared dataset in Huggingface hub")
if cfg.is_preprocess:
LOG.info(
f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
)
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
@@ -198,6 +210,8 @@ def load_tokenized_prepared_datasets(
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
@@ -208,6 +222,8 @@ def load_tokenized_prepared_datasets(
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -428,10 +444,12 @@ def load_tokenized_prepared_datasets(
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
cfg.push_dataset_to_hub,
ds_hash,
private=True,
)
return dataset, prompters
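In other words, the dataset hash moves from being part of the repo id to being a dataset config name under a single repo. A hedged sketch with hypothetical values:

repo_id = "my-org/prepared-datasets"  # cfg.push_dataset_to_hub
ds_hash = "a1b2c3d4"                  # hash of the dataset/tokenizer config

# before: one repo per hash
#   dataset.push_to_hub(f"{repo_id}/{ds_hash}", private=True)
# after: one repo, one config per hash
#   dataset.push_to_hub(repo_id, ds_hash, private=True)
#   load_dataset(repo_id, ds_hash, token=use_auth_token)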

View File

@@ -44,6 +44,10 @@ def is_main_process():
return dist.get_rank() == 0
def is_local_main_process():
return PartialState().is_main_process
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
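For context (not axolotl code): accelerate's PartialState exposes both the global and the per-node main-process flags, which is the distinction the prepare_dataset change above leans on:

from accelerate import PartialState

state = PartialState()  # safe to instantiate even in a single-process run
print(state.is_main_process)        # True only on global rank 0
print(state.is_local_main_process)  # True on rank 0 of every node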
@@ -149,11 +153,11 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
)
else:
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device()
0.0, device=torch.cuda.current_device(), dtype=torch.float32
) # Placeholder tensor
# Broadcast the tensor to all processes.

View File

@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
@@ -173,6 +174,7 @@ def load_sharded_model_quant(
low_memory=True,
verbose=False,
loading_workers=2,
quantization_config=None,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
else:
# this is the more common case with HF transformers
# TODO can we detect the model arch and dynamically set skip_modules
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
model.is_loaded_in_4bit = True
@@ -251,6 +264,11 @@ def load_sharded_model_quant(
quant_method=quant_method,
)
# these attributes are needed to inform transformers/peft of the quantization
model.is_quantized = True
model.quantization_method = "bitsandbytes"
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading

View File

@@ -1,7 +1,7 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import gc
import logging
import math
import os
@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
PreTrainedModel,
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -94,12 +96,6 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
if (
cfg.adapter
@@ -346,7 +342,36 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
patch_for_multipack(
cfg.model_config_type,
model_name=cfg.base_model,
is_remote_code=cfg.trust_remote_code,
)
if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
elif cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
)
integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
@@ -399,7 +424,7 @@ def load_model(
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch()
integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -407,23 +432,12 @@ def load_model(
patch_self_attn_lora()
# Modify mistral derived models
if (
cfg.model_config_type == "mistral"
and cfg.flash_attention
and cfg.sample_packing
):
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn,
patch_mistral_cross_entropy,
)
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
patch_mistral_cross_entropy()
model_kwargs: Dict[str, Any] = {}
@@ -496,7 +510,25 @@ def load_model(
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
if cfg.adapter == "qlora" and cfg.load_in_4bit:
if (
cfg.adapter in ["qlora", "lora"]
and hasattr(model_config, "quantization_config")
and model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
):
if model_config.quantization_config["quant_method"] == "gptq":
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "awq":
model_kwargs["quantization_config"] = AwqConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**model_config.quantization_config
)
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -506,7 +538,9 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
cfg.deepspeed or cfg.fsdp
):
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -551,16 +585,10 @@ def load_model(
"flash_attention_2"
)
else:
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
@@ -590,14 +618,21 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
):
quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
@@ -605,7 +640,7 @@ def load_model(
and not cfg.trust_remote_code
and not cfg.gptq
):
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
@@ -687,7 +722,7 @@ def load_model(
**model_kwargs,
)
else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True
if "device_map" in model_kwargs:
@@ -771,12 +806,16 @@ def load_model(
set_z3_leaf_modules,
)
if cfg.model_config_type == "mixtral":
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
set_z3_leaf_modules(model, [moe_block])
elif cfg.model_config_type == "dbrx":
moe_block = get_module_class_from_name(model, "DbrxFFN")
set_z3_leaf_modules(model, [moe_block])
if cfg.model_config_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
model,
[
get_module_class_from_name(model, module_name)
for module_name in moe_blocks
],
)
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
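For illustration, the MOE_ARCH_BLOCK lookup replaces the per-architecture branches above; a mapping of roughly this shape, reconstructed from the two branches it supersedes (the real constant in axolotl.common.architectures may cover more model types):

MOE_ARCH_BLOCK_EXAMPLE = {
    "mixtral": "MixtralSparseMoeBlock",
    "dbrx": "DbrxFFN",
}

blocks = MOE_ARCH_BLOCK_EXAMPLE["mixtral"]
blocks = [blocks] if isinstance(blocks, str) else blocks  # normalize str -> list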
@@ -790,6 +829,9 @@ def load_model(
# make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable(
@@ -824,6 +866,9 @@ def load_model(
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
if (
cfg.ddp
and not load_in_8bit
@@ -863,6 +908,15 @@ def load_model(
integrate_lora_patch(model, cfg)
if cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
# TODO resume_from_checkpoint handling
return model, lora_config
@@ -960,7 +1014,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}")
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config_kwargs = {}
@@ -1036,9 +1090,20 @@ def load_lora(model, cfg, inference=False, config_only=False):
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
weight_mismatch = False
bias_mismatch = False
try:
if module.weight.dtype != dtype:
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
module.to(dtype)
weight_mismatch = module.weight.dtype != dtype
except AttributeError:
pass
try:
bias_mismatch = module.bias.dtype != dtype
except AttributeError:
pass
if weight_mismatch:
print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
if bias_mismatch:
print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
if weight_mismatch or bias_mismatch:
module.to(dtype)
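A small usage sketch of the updated helper, assuming ensure_dtype is importable from axolotl.utils.models (the module shown in this diff): a module whose weight is already correct but whose bias is not now gets converted too.

import torch

from axolotl.utils.models import ensure_dtype

layer = torch.nn.Linear(4, 4, bias=True)                  # fp32 weight and bias
layer.weight.data = layer.weight.data.to(torch.bfloat16)  # weight fixed, bias left fp32
ensure_dtype(layer, dtype=torch.bfloat16)                 # bias mismatch now detected
assert layer.bias.dtype == torch.bfloat16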

View File

@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens)
for token in tokenizer.encode(tokens, add_special_tokens=False)
]
return colored_tokens

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions"""
import json
import math
import os
import random
@@ -15,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl")
@@ -182,90 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if (
cfg.is_mistral_derived_model and cfg.flash_attention
) or cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
train_dataset = train_dataset.filter(
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
return train_dataset, eval_dataset
@@ -391,6 +390,26 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
return total_num_steps
def setup_torch_compile_env(cfg):
if cfg.torch_compile:
if not cfg.torch_compile_backend:
os.environ["ACCELERATE_DYNAMO_BACKEND"] = "INDUCTOR"
else:
os.environ["ACCELERATE_DYNAMO_BACKEND"] = cfg.torch_compile_backend.upper()
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
HfTrainerDeepSpeedConfig(cfg.deepspeed)
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing:
@@ -417,8 +436,16 @@ def prepare_optim_env(cfg):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
stage = None
# check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
setup_torch_compile_env(cfg)
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
@@ -426,8 +453,14 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

View File

@@ -0,0 +1,341 @@
"""
E2E tests for multigpu lora tinyllama
"""
import logging
import os
import unittest
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
class TestMultiGPULlama(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_lora_ddp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 50,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp": [
"full_shard",
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp": [
"full_shard",
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@pytest.mark.skip("disabled due to upstream issue")
@with_temp_dir
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
"tokenizer_type": "AutoTokenizer",
"adapter": "qlora",
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": [
"embed_tokens",
"lm_head",
],
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|end_of_text|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:25%]",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp": [
"full_shard",
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -0,0 +1,98 @@
"""
E2E tests for multigpu qwen2
"""
import logging
import os
import unittest
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
class TestMultiGPUQwen2(unittest.TestCase):
"""
Test case for Qwen2 models using QLoRA with FSDP and DPO
"""
@with_temp_dir
def test_qlora_fsdp_dpo(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2-1.5B",
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 100,
"warmup_steps": 20,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
"tf32": True,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"fsdp": [
"full_shard",
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
},
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
import unittest
import transformers
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
load_model(cfg, tokenizer, inference=cli_args.inference)
assert (
"axolotl.monkeypatch.mistral_attn_hijack_flash"
in model.model.layers[0].self_attn.forward.__module__
"torch.jit"
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
)

tests/e2e/test_imports.py Normal file
View File

@@ -0,0 +1,20 @@
"""
test module to import various submodules that have historically broken due to dependency issues
"""
import unittest
class TestImports(unittest.TestCase):
"""
Test class to import various submodules that have historically broken due to dependency issues
"""
def test_import_causal_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
HFCausalTrainerBuilder,
)
def test_import_rl_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
HFRLTrainerBuilder,
)

View File

@@ -0,0 +1,67 @@
"""
E2E tests for llama pretrain
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase):
"""
Test case for Llama models w pretraining
"""
@with_temp_dir
def test_pretrain_w_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"pretraining_dataset": [
{
"path": "allenai/c4",
"name": "en",
"type": "pretrain",
}
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -65,45 +65,3 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_q_galore_adamw8bit(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "q_galore_adamw8bit",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -2,6 +2,7 @@
tests for chat_template prompt strategy
"""
import logging
import unittest
import pytest
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
{
"role": "assistant",
"content": "goodbye",
},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hello"},
{"role": "user", "content": "goodbye"},
{"role": "assistant", "content": "goodbye"},
]
}
]
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
[
{
"conversations": [
{
"from": "human",
"value": "hello",
},
{
"from": "gpt",
"value": "hello",
},
{
"from": "human",
"value": "goodbye",
},
{
"from": "gpt",
"value": "goodbye",
},
{"from": "human", "value": "hello"},
{"from": "gpt", "value": "hello"},
{"from": "human", "value": "goodbye"},
{"from": "gpt", "value": "goodbye"},
]
}
]
)
@pytest.fixture(name="basic_dataset")
def fixture_basic_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"conversations": [
{"from": "system", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "assistant", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "assistant", "value": "I'm doing well, thank you!"},
]
}
]
@@ -77,19 +75,611 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer
class TestChatTemplateConfigurations:
"""
Test class for various configurations of ChatTemplateStrategy.
"""
@staticmethod
def find_sublist(full_list, sub_list):
token_count = len(sub_list)
for index in range(len(full_list) - token_count + 1):
if full_list[index : index + token_count] == sub_list:
return index
return -1
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
# Check the behavior of human inputs
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_text_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_text_ids)
labeled = all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(input_text_ids)]
)
LOG.debug(
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_text_ids}, found at: {start_idx}"
)
LOG.debug("Full labels: %s", labels)
LOG.debug("Full input_ids: %s", input_ids)
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that only assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
# Verify that human inputs are not labeled
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_text_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_text_ids)
LOG.debug(
f"Human input '{input_text}' expected IDs: {input_text_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
assert all(
label == IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(input_text_ids)]
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_text_ids)]}"
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that only assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["human", "assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that all responses are labeled (except for special tokens)
all_responses = [
"Hello",
"Hi there!",
"How are you?",
"I'm doing well, thank you!",
]
for response in all_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
train_on_eos="none", # Add this line
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
# Verify that no labels are set when roles_to_train is empty
LOG.debug("Full labels: %s", labels)
assert all(
label == IGNORE_TOKEN_ID for label in labels
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="all",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to be labeled"
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="turn",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
eos_idx = start_idx + len(response_ids)
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert eos_idx < len(
input_ids
), f"Could not find EOS token after '{response}'"
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token after assistant response '{response}' to be labeled"
# Check that EOS tokens after human inputs are not labeled
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_text_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_text_ids)
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
eos_idx = start_idx + len(input_text_ids)
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token after human input '{input_text}' to not be labeled"
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="last",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
last_eos_idx = eos_indices[-1]
# Check that only the last EOS token is labeled
for idx in eos_indices[:-1]:
assert (
labels[idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {idx} to not be labeled"
assert (
labels[last_eos_idx] != IGNORE_TOKEN_ID
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="none",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to not be labeled"
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with drop_system_message=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
# Check if system message is not present in input_ids
system_message = "You are an AI assistant."
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
assert (
self.find_sublist(input_ids, system_ids) == -1
), "Expected system message to be dropped"
def test_custom_roles(self, llama3_tokenizer):
LOG.info("Testing with custom roles mapping")
custom_roles = {
"user": ["human", "user"],
"assistant": ["ai", "assistant"],
"system": ["context"],
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["ai"],
)
# Create a new dataset with modified role names
modified_conversations = [
{"from": "context", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "ai", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "ai", "value": "I'm doing well, thank you!"},
]
modified_dataset = Dataset.from_dict(
{"conversations": [modified_conversations]}
)
res = strategy.tokenize_prompt(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Check if AI responses are labeled correctly
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in ai_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for AI response '{response}' to be set"
# Check if human messages are not labeled
human_messages = ["Hello", "How are you?"]
for message in human_messages:
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, message_ids)
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
assert all(
label == IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(message_ids)]
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
def test_message_field_training(self, llama3_tokenizer):
LOG.info("Testing with message_field_training")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
)
# Create a new dataset with the train and train_detail fields
modified_conversation = [
{"from": "system", "value": "You are an AI assistant.", "train": False},
{"from": "human", "value": "Hello", "train": False},
{"from": "assistant", "value": "Hello", "train": True},
{"from": "human", "value": "How are you?", "train": True},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": False},
{"begin_offset": 9, "end_offset": 18, "train": True},
{"begin_offset": 19, "end_offset": 30, "train": False},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": False,
},
{"from": "assistant", "value": "Hi there!", "train": True},
]
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
res = strategy.tokenize_prompt(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Function to find all occurrences of a sublist
def find_all_sublists(full_list, sub_list):
indices = []
for index in range(len(full_list) - len(sub_list) + 1):
if full_list[index : index + len(sub_list)] == sub_list:
indices.append(index)
return indices
# Keep track of which occurrences we've processed
processed_occurrences = {}
# Check if messages are labeled correctly based on train or train_detail
for i, turn in enumerate(modified_conversation):
turn_tokens = llama3_tokenizer.encode(
turn["value"], add_special_tokens=False
)
occurrences = find_all_sublists(input_ids, turn_tokens)
turn_key = turn["value"]
if turn_key not in processed_occurrences:
processed_occurrences[turn_key] = 0
current_occurrence = processed_occurrences[turn_key]
assert (
current_occurrence < len(occurrences)
), f"Not enough occurrences found for message: {turn['value']}"
start_idx = occurrences[current_occurrence]
processed_occurrences[turn_key] += 1
end_idx = start_idx + len(turn_tokens)
LOG.debug(
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
)
if "train_detail" in turn:
# Get token offsets
tokenized_output = llama3_tokenizer(
turn["value"], return_offsets_mapping=True, add_special_tokens=False
)
token_offsets = tokenized_output["offset_mapping"]
# Adjust token offsets as done in the implementation
for offset_idx in range(len(token_offsets) - 1):
token_offsets[offset_idx] = (
token_offsets[offset_idx][0],
token_offsets[offset_idx + 1][0] - 1,
)
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
# Adjust train_details
adjusted_train_details = strategy.prompter.adjust_train_details(
turn["train_detail"], token_offsets
)
LOG.debug(f"Original train_details: {turn['train_detail']}")
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
# Handle train_detail
token_offsets = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=False,
)
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=True,
)
LOG.debug(f"Token offsets: {token_offsets_masked}")
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
for token_idx, offset in enumerate(token_offsets_masked):
if offset != IGNORE_TOKEN_ID:
expected_labels[token_idx] = turn_tokens[token_idx]
actual_labels = labels[
start_idx : start_idx + len(token_offsets_masked)
]
assert (
actual_labels == expected_labels
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
for detail in adjusted_train_details:
# Find the token indices that correspond to the character offsets
detail_start = start_idx + next(
i
for i, offset in enumerate(token_offsets)
if offset >= detail["begin_offset"]
)
detail_end = start_idx + next(
(
i
for i, offset in enumerate(token_offsets)
if offset > detail["end_offset"]
),
len(token_offsets),
)
detail_text = turn["value"][
detail["begin_offset"] : detail["end_offset"] + 1
]
detail_labels = labels[detail_start:detail_end]
detail_input_ids = input_ids[detail_start:detail_end]
LOG.debug(
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
)
LOG.debug(f"Detail input_ids: {detail_input_ids}")
LOG.debug(f"Detail labels: {detail_labels}")
LOG.debug(
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
)
LOG.debug(
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
)
if detail["train"]:
assert all(
label != IGNORE_TOKEN_ID for label in detail_labels
), (
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
f"InputIDs: {detail_input_ids}, "
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in detail_labels
), (
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
f"InputIDs: {detail_input_ids}, "
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
)
else:
should_train = turn.get("train", False)
turn_labels = labels[start_idx:end_idx]
LOG.debug(f"Should train: {should_train}")
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
LOG.debug(f"Turn labels: {turn_labels}")
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
LOG.debug(
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
)
if should_train:
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected all labels for '{turn['value']}' to be set\n"
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
f"InputIDs: {input_ids[start_idx:end_idx]}, "
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
)
else:
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
f"InputIDs: {input_ids[start_idx:end_idx]}, "
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
)
LOG.debug(
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
f"start_idx: {start_idx}, end_idx: {end_idx}, "
f"labels: {labels[start_idx:end_idx]}"
)
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
class TestAssistantChatTemplateLlama3:
"""
Test class for assistant-style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load(
llama3_tokenizer,
DictDefault(
@@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3:
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
LOG.info("Testing llama-3 with assistant dataset")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3:
"system": ["system"],
},
),
llama3_tokenizer,
False,
512,
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
@@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3:
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
roles={
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
prompt_tokens = strategy.prompter.build_prompt(
assistant_dataset[0]["messages"], False
)
prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
LOG.debug(f"Generated prompt: {prompt}")
res = strategy.tokenize_prompt(assistant_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# fmt: off
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert labels == expected_labels, (
f"Labels mismatch:\n"
f"Expected: {expected_labels}\n"
f"Actual: {labels}\n"
f"Input IDs: {input_ids}\n"
)
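With train_on_eos="none", the expected labels above mask every 128009 EOS/EOT token while the trainable assistant content tokens (15339, 19045, 29474) keep their ids. A minimal sketch of that post-processing, assuming an "all" value that would train every EOS also exists (any policy name other than "none" is an assumption here):
def apply_eos_policy(labels, input_ids, eos_token_id, train_on_eos, ignore_id=-100):
    # 'none' masks every EOS label; 'all' would keep every EOS trainable.
    for idx, token in enumerate(input_ids):
        if token == eos_token_id:
            labels[idx] = token if train_on_eos == "all" else ignore_id
    return labels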
class TestSharegptChatTemplateLlama3:
@@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3:
Test class for ShareGPT-style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
# pylint: disable=duplicate-code
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
llama3_tokenizer,
False,
512,
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["gpt"],
)
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
assert input_ids == [
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["human"],
)
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["system", "human"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 9125, 128007,
271, 2675, 527, 459, 15592, 18328, 13, 128009,
128006, 882, 128007, # user header
271, 9906, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 13347, 1070, 0, 128009, # assistant response eot
128006, 882, 128007,
271, 4438, 527, 499, 30, 128009,
128006, 78191, 128007,
271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,156 @@
"""
Tests for the DPO chat_template prompt strategy.
"""
import unittest
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
],
"chosen": {
"role": "assistant",
"content": "goodbye",
},
"rejected": {
"role": "assistant",
"content": "party on",
},
}
]
)
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"conversation": [
{
"speaker": "human",
"text": "hello",
},
{
"speaker": "agent",
"text": "hello",
},
{
"speaker": "human",
"text": "goodbye",
},
],
"better": {
"speaker": "agent",
"text": "goodbye",
},
"worse": {
"speaker": "agent",
"text": "party on",
},
}
]
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant-style datasets with llama-3 prompts using the DPO chat_template strategy.
"""
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
"message_field_role": "speaker",
"message_field_content": "text",
"roles": {
"user": ["human"],
"assistant": ["agent"],
"system": ["sys"],
},
}
],
}
)
)
result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"
if __name__ == "__main__":
unittest.main()


@@ -192,6 +192,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
# pylint: disable=duplicate-code
assert input_ids == [
128000, # bos
128006, 9125, 128007, # system header
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
# pylint: disable=duplicate-code
assert input_ids == [
128000, # bos
128006, 9125, 128007, # system header


@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello",
]
}
result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
self.assertEqual(len(result["input_ids"]), 3)


@@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase):
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
"c4",
"allenai/c4",
"en",
streaming=True,
)["train"]
@@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase):
{
"pretraining_dataset": [
{
"path": "c4",
"path": "allenai/c4",
"name": "en",
"type": "pretrain",
}


@@ -42,6 +42,19 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "USER:" not in res
assert "ASSISTANT:" not in res
def test_prompt_style_w_phi(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value)
res = next(prompter.build_prompt("tell me a joke about the following"))
assert (
"""<|system|>
Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|>
<|user|>
tell me a joke about the following<|end|>
<|assistant|>
"""
== res
)
def test_prompt_style_w_chat(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(