Compare commits

..

22 Commits

Author SHA1 Message Date
sunny
bfb80a3ef9 stuff 2024-10-30 13:44:06 -04:00
sunny
38773d661f fixing 2024-10-30 11:04:50 -04:00
sunny
271c2c2b82 fixed formatting 2024-10-29 15:50:56 -04:00
sunny
32b6f30947 fix attempt at issue 1991 2024-10-29 15:44:32 -04:00
sunny
fc1f275e6c yml change 2024-10-29 15:27:42 -04:00
sunny
46d2b4ce89 yml change 2024-10-29 15:25:25 -04:00
sunny
88c9a7aecc LOG for debug 2024-10-29 13:35:55 -04:00
sunny
d9a93990d1 yml 2024-10-29 10:40:32 -04:00
Oliver Kunc
107b67b852 Hardware requirements (#1997) [skip ci]
* Hardware requirements

https://github.com/axolotl-ai-cloud/axolotl/issues/1992

* Update README.md

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-29 10:13:50 -04:00
NanoCode012
bfc77b0f36 Feat: Add support for tokenizer’s or custom jinja chat_template (#1970)
* Allow using tokenizer's default chat template with fallbacks

Summary of changes:

1. Adds `tokenizer_default` as option for `chat_template` in
   `chat_template` prompt strategy that allows using the chat template
   from tokenizer's config.json
2. Allows falling back to chat templates available in axolotl if
   tokenizer does not have a chat template
3. Adds a mistral chat template which supports system message - taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why?

Many popular models are not trained with chatml format. As a result for
the model to correctly learn chatml we have to turn on train_on_inputs
which requires more compute and time. If we can use the model's already
learned chat template we can just learn the output tokens

---

Todo:

- Write tests

* Add tests

* Fix lint and bug post merge from main

* Add option `chat_template_jinja` to provide a jinja template

* remove custom mistral template

* Address review comments and add docs

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* fix: set default to tokenizer template

* Merge branch 'main' into cj_tokenizer_default_prompt_template

* chore: remove redundant function

* fix: re-arrange enum declaration position

* fix: refactor artifact left from main merge

* feat(doc): updated config with chat template options and clarified examples

* chore: clarify doc

* chore: added example for non-default template

* chore: refactor

* fix: test

* fix: config being dropped and unittest to catch that

* chore: lint

* chore: skip duplicate

* fix: rename var after merge

* feat: add test for levy's dpo case

* fix: remove default setting on edge case where chat template overriden in dataset section

* feat: handle sharegpt deprecation better in docs

* feat: add example using fallback

* feat: handles chat_template requiring specific user/assistant order

* fix: update test based on new defaults

* fix: imported name incorrectly updated on merge

* chore: lint

* fix: update dummy message to prevent potential overlap with real content

* fix(doc): formatting

* fix: update bradleyterry to use new chat_template

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-10-29 10:14:51 +07:00
Wing Lian
e1e0556c99 add option for resizing embeddings when adding new tokens (#2000)
* add option for resizing embeddings when adding new tokens

* let's just be opinonated about this setting and set it to False
2024-10-28 17:02:04 -04:00
Wing Lian
d3c45d27b5 fix zero3 (#1994) 2024-10-28 07:32:49 -04:00
NanoCode012
2501c1a6a3 Fix: Gradient Accumulation issue (#1980)
* feat: support new arg num_items_in_batch

* use kwargs to manage extra unknown kwargs for now

* upgrade against upstream transformers main

* make sure trl is on latest too

* fix for upgraded trl

* fix: handle trl and transformer signature change

* feat: update trl to handle transformer signature

* RewardDataCollatorWithPadding no longer has max_length

* handle updated signature for tokenizer vs processor class

* invert logic for tokenizer vs processor class

* processing_class, not processor class

* also handle processing class in dpo

* handle model name w model card creation

* upgrade transformers and add a loss check test

* fix install of tbparse requirements

* make sure to add tbparse to req

* feat: revert kwarg to positional kwarg to be explicit

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-25 11:28:23 -04:00
Mengqing Cao
1d6a5e2bd6 Refactor func load_model to class ModelLoader (#1909) 2024-10-25 09:06:56 -04:00
Wing Lian
718cfb2dd1 revert image tagged as main-latest (#1990) 2024-10-22 13:54:24 -04:00
Adam Hazell
9bd5f7d015 Log checkpoints as mlflow artifacts (#1976)
* Ensure hf_mlflow_log_artifact config var is set in env

* Add transformer MLflowCallback to callbacks list when mlflow enabled

* Test hf_mlflow_log_artifacts is set correctly

* Test mlflow not being used by default
2024-10-22 08:52:21 -04:00
Wing Lian
5c629ee444 use torch 2.4.1 images as latest now that torch 2.5.0 is out (#1987) 2024-10-21 19:51:06 -04:00
Wing Lian
955cca41fc don't explicitly set cpu pytorch version (#1986)
use a constraint file
use min version of xformers
don't install autoawq with pytorch 2.5.0
debugging for errors
upgrade pip first
fix action yml
add back try/except
retry w/o constraint
use --no-build-isolation
show torch version
install setuptools and wheel
add back try/except
2024-10-21 19:50:50 -04:00
Wing Lian
e12a2130e9 first pass at pytorch 2.5.0 support (#1982)
* first pass at pytorch 2.5.0 support

* attempt to install causal_conv1d with mamba

* gracefully handle missing xformers

* fix import

* fix incorrect version, add 2.5.0

* increase tests timeout
2024-10-21 11:00:45 -04:00
Wing Lian
67f744dc8c add pytorch 2.5.0 base images (#1979)
* add pytorch 2.5.0 base images

* make sure num examples for debug is zero and fix comparison
2024-10-18 03:36:51 -04:00
Sunny Liu
f62e23737b memoize dataset length for eval sample packing (#1974)
* wip on multimodal sample packing support

* wip on multimodal packing support

* llama-1b-yml

* setup logging for test

* yml

* yml

* yml

* fix for __len__ for eval sample packing

* reverted irrelavant changes

* reformatted, reverted log message

* reverted unnecessary changes

* added e2e multigpu testing for eval sample packing

* formatting

* fixed e2e test_eval params

* fix test_eval e2e multigpu

* fix test_eval e2e multigpu

* Update tests/e2e/multigpu/test_eval.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update tests/e2e/multigpu/test_eval.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-17 15:15:29 -04:00
Wing Lian
54673fd6ca also debug if other debug args are set (#1977) 2024-10-17 14:12:31 -04:00
55 changed files with 2865 additions and 1031 deletions

View File

@@ -36,6 +36,12 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3

View File

@@ -29,6 +29,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -86,6 +91,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -21,10 +21,17 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 121
cuda_version: 12.1.1
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.3.1
pytorch: 2.4.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"

View File

@@ -28,6 +28,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -85,6 +90,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
run: |
pip3 install wheel packaging
pip3 install -e .
pip3 install -r requirements-tests.txt
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name
id: tag

View File

@@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
@@ -47,13 +47,14 @@ jobs:
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-tests.txt
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
@@ -95,6 +96,13 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -36,7 +36,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
@@ -49,16 +49,20 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 show torch
pip3 install -U -e .
pip3 install -r requirements-tests.txt
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
@@ -72,7 +76,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
@@ -97,6 +101,12 @@ jobs:
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

295
1991.yml Normal file
View File

@@ -0,0 +1,295 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
eos_token: <|im_end|>

View File

@@ -121,7 +121,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template)
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt

View File

@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi
# So we can test the Docker image
RUN pip install -r requirements-tests.txt
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
#!/bin/bash
set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)

View File

@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
timeout=60 * 60,
cpu=8.0,
memory=131072,
)

View File

@@ -14,15 +14,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -24,15 +24,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,15 +20,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -83,7 +83,7 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
@@ -124,6 +124,48 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
system: ["system"]
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -142,9 +184,16 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so

View File

@@ -6,6 +6,8 @@ order: 3
## sharegpt
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```{.json filename="data.jsonl"}
@@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```
See `config.qmd` for full configs and supported templates.
### Migrating from sharegpt
Most configs can be adapted as follows:
```yaml
# old
chat_template: chatml
datasets:
- path: ...
type: sharegpt
conversation: chatml
# new (if using tokenizer's chat_template)
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
```
We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
- path: ...
type: chat_template
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
```{.json filename="data.jsonl"}
{
"conversations": [
{"from": "system", "value": "You are an AI assistant.", "train": false},
{"from": "human", "value": "Hello", "train": false},
{"from": "assistant", "value": "Hello", "train": true},
{"from": "human", "value": "How are you?", "train": true},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": false},
{"begin_offset": 9, "end_offset": 18, "train": true},
{"begin_offset": 19, "end_offset": 30, "train": false},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": true,
},
{"from": "assistant", "value": "Hi there!", "train": true}
]
}
```
The configuration would look like:
```yaml
datasets:
- path: ...
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.

View File

@@ -1,50 +0,0 @@
base_model: meta-llama/Llama-3.1-8B-Instruct
save_safetensors: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: ./last_run_prepared
output_dir: ./outputs/fft-out
sequence_len: 8192
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
learning_rate: 2e-5
bf16: auto
fp16:
tf32: false
logging_steps: 2
xformers_attention:
flash_attention: true
warmup_steps: 2
evals_per_epoch: 2
save_steps: 2
max_steps: 2
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: true
fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -0,0 +1,77 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -2,3 +2,4 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,12 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.45.2
transformers==4.46.0
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.14.4
deepspeed==0.15.3
pydantic==2.6.3
addict
fire
@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
xformers==0.0.28.post1
xformers>=0.0.23.post1
optimum==1.16.2
hf_transfer
colorama
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.9.6
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
zstandard==0.22.0
fastcore

View File

@@ -31,6 +31,8 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -50,10 +52,16 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 4):
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
@@ -73,7 +81,6 @@ def parse_requirements():
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
@@ -102,6 +109,7 @@ setup(
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",

View File

@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -462,7 +462,12 @@ def load_datasets(
processor=processor,
)
if cli_args.debug or cfg.debug:
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)

View File

@@ -7,6 +7,7 @@ import abc
import gc
import importlib
import importlib.util
import inspect
import logging
import math
import os
@@ -27,7 +28,6 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -63,7 +63,7 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -666,7 +666,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
@@ -674,8 +676,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -771,7 +783,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(self, model, inputs, return_outputs=False):
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
@@ -877,13 +895,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
def _save_checkpoint(self, model, trial):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
return super()._save_checkpoint(model, trial)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -898,6 +916,7 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
@@ -1005,18 +1024,32 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
self,
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
res = super().tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs)
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
@@ -1119,12 +1152,17 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -1556,7 +1594,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
)
@@ -1662,12 +1700,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(),
**trainer_kwargs,
@@ -1708,6 +1751,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
@@ -1910,7 +1955,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -1922,11 +1967,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)

View File

@@ -22,7 +22,6 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -44,7 +43,19 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_available() -> bool:
try:
import xformers # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
def is_xformers_swiglu_available() -> bool:
if not is_xformers_available():
return False
from xformers.ops.common import get_xformers_operator
try:
@@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool:
def replace_llama_mlp_with_swiglu(model):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
@@ -181,49 +197,6 @@ class FusedAttention(LlamaAttention):
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(

View File

@@ -27,15 +27,18 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
elif hasattr(transformers, "modeling_flash_attention_utils"):
if not has_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return

View File

@@ -16,26 +16,6 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -80,12 +60,6 @@ def get_forward_code() -> str:
return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
logits,
labels,
vocab_size: int, # pylint: disable=unused-argument
num_items_in_batch: int = None,
ignore_index: int = -100, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
from transformers.loss import loss_utils
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
else:
raise ValueError("Unsupported model type")

View File

@@ -0,0 +1,51 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)

View File

@@ -6,7 +6,7 @@ import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,13 +2,18 @@
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
@@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(

View File

@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -405,10 +405,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),

View File

@@ -2,15 +2,16 @@
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -30,6 +31,12 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages]
messages = [
{
@@ -46,28 +53,29 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
[dummy_user_message, chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
[dummy_user_message, rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
class Message(BaseModel):
@@ -28,18 +28,13 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
chat_template = chat_templates("chatml")
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
tokenizer.chat_template = chat_template_string
return ORPOTokenizingStrategy(
ORPOPrompter(chat_template, tokenizer),
ORPOPrompter(chat_template_string, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -248,28 +243,30 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]

View File

@@ -62,7 +62,7 @@ def build_loader(
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
)
conversation = (
ds_cfg["conversation"]

View File

@@ -260,8 +260,10 @@ def train(
if not cfg.hub_model_id:
try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -2,8 +2,19 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional
CHAT_TEMPLATES = {
if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase
LOG = logging.getLogger("axolotl.utils.chat_templates")
_JINJA_TEMPALTE_CHOICE = "jinja"
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
_CHAT_TEMPLATES = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
@@ -21,12 +32,18 @@ CHAT_TEMPLATES = {
}
def chat_templates(user_choice: str):
def get_chat_template(
user_choice: str,
jinja_template: Optional[str] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
):
"""
Finds the correct chat_template for the tokenizer_config.
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
Args:
user_choice (str): The user's choice of template.
jinja_template (Optional[str], optional): The jinja template string. Defaults to None.
tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None.
Returns:
str: The chosen template string.
@@ -34,13 +51,71 @@ def chat_templates(user_choice: str):
Raises:
ValueError: If the user_choice is not found in the templates.
"""
if user_choice == _JINJA_TEMPALTE_CHOICE:
if not jinja_template:
raise ValueError(
f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}"
)
return jinja_template
if user_choice in CHAT_TEMPLATES:
return CHAT_TEMPLATES[user_choice]
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
)
if not tokenizer.chat_template:
raise ValueError(
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
f"Please add a chat_template in tokenizer config"
)
return tokenizer.chat_template
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
)
if tokenizer.chat_template:
return tokenizer.chat_template
user_choice = user_choice[
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
]
LOG.warning(
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
)
if user_choice in _CHAT_TEMPLATES:
return _CHAT_TEMPLATES[user_choice]
raise ValueError(f"Template '{user_choice}' not found.")
def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
if ds_cfg and ds_cfg.get("chat_template"):
chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = ds_cfg.get("chat_template_jinja")
else:
chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = cfg.get("chat_template_jinja")
return chat_template_choice, chat_template_jinja
def get_chat_template_from_config(
cfg,
ds_cfg: Optional[Dict[str, Any]] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
return get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
def register_chat_template(template_name: str, chat_template: str):
"""
Registers chat templates.
@@ -50,7 +125,7 @@ def register_chat_template(template_name: str, chat_template: str):
chat_template (str): The template string.
"""
if template_name in CHAT_TEMPLATES:
if template_name in _CHAT_TEMPLATES:
raise ValueError(f"Template '{template_name}' already exists.")
CHAT_TEMPLATES[template_name] = chat_template
_CHAT_TEMPLATES[template_name] = chat_template

View File

@@ -228,6 +228,7 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):

View File

@@ -8,9 +8,16 @@ import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
StringConstraints,
conlist,
field_validator,
model_validator,
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -21,6 +28,37 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -105,13 +143,19 @@ class SFTDataset(BaseModel):
input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
chat_template: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
Union[
ChatTemplate,
str,
]
] = None
chat_template_jinja: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -122,13 +166,32 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -174,35 +237,6 @@ class KTODataset(BaseModel):
revision: Optional[str] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -549,6 +583,7 @@ class AxolotlInputConfig(
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
@@ -718,7 +753,13 @@ class AxolotlInputConfig(
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
chat_template: Optional[ChatTemplate] = None
chat_template: Optional[
Union[
ChatTemplate,
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
]
] = None
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
@@ -827,6 +868,23 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):

View File

@@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault):
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True
# Enable logging hf artifacts in mlflow if value is truthy
if cfg.hf_mlflow_log_artifacts is True:
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"

File diff suppressed because it is too large Load Diff

View File

@@ -133,6 +133,8 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0
self.eff_total_slots = 0
self.len_across_ranks = None
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -195,15 +197,14 @@ class MultipackBatchSampler(BatchSampler):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
min_len_batches = reduce_and_broadcast(
lambda: num,
calc_min_len,
)
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
len_batches = self.num_batches()
return self.gather_len_batches(len_batches)
if not self.len_across_ranks:
len_batches = self.num_batches()
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks
def _len_est(self):
efficiency = (

View File

@@ -0,0 +1,155 @@
"""
E2E tests for multigpu eval
"""
import logging
import os
import unittest
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPUEval(unittest.TestCase):
"""
Test case for MultiGPU Eval Sample Packing
"""
@with_temp_dir
def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -1,22 +1,12 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.unsloth_ import (
check_cel_is_patchable,
check_self_attn_is_patchable,
)
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_cel_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_cel_is_patchable(),
"HF transformers loss code has changed and isn't patchable",
)
def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(

View File

@@ -0,0 +1,95 @@
"""Module for testing ModelLoader."""
import shutil
import tempfile
import pytest
import torch
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
def fixture_temp_dir():
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
class TestLoadModelUtils:
"""
Testing module testing ModelLoader.
"""
def setup_method(self):
# load config
self.cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"tokenizer_config": "JackFram/llama-68m",
"sequence_len": 1024,
"load_in_8bit": False,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer="",
)
)
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
@pytest.mark.parametrize(
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
)
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
def test_convert_embedding_modules_dtype(
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():
if (
"norm" in name
or (before_kbit_train_or_finetune and name.endswith(".gate"))
or (
any(m in name for m in embedding_modules)
and hasattr(module, "weight")
)
):
for _, param in module.named_parameters():
assert param.dtype == dist_dtype

View File

@@ -0,0 +1,74 @@
"""
E2E tests for packed training
"""
import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -0,0 +1,125 @@
"""
Tests for utils in axolotl.utils.chat_templates
"""
import unittest
import pytest
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import (
_CHAT_TEMPLATES,
extract_chat_template_args,
get_chat_template,
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
return tokenizer
class TestGetChatTemplateUtils:
"""
Tests the get_chat_template function.
"""
def test_known_chat_template(self):
chat_template_str = get_chat_template("llama3")
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
def test_invalid_chat_template(self):
with pytest.raises(ValueError) as exc:
get_chat_template("invalid_template")
assert str(exc) == "Template 'invalid_template' not found."
def test_tokenizer_default_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=None)
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_tokenizer_default_fallback_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
self, llama3_tokenizer
):
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == get_chat_template("chatml")
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
self, llama3_tokenizer
):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_jinja_template_mode(self):
jinja_template = "example_jinja_template"
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
assert chat_template_str == jinja_template
def test_jinja_template_mode_no_jinja_template(self):
with pytest.raises(ValueError):
get_chat_template("jinja", jinja_template=None)
def test_extract_chat_template_args(self):
# No ds_cfg
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml"},
)
assert chat_template_choice == "chatml"
assert chat_template_jinja is None
# ds_cfg provided
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
)
assert chat_template_choice == "llama3"
assert chat_template_jinja is None
# ds_cfg provided with jinja template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml", "chat_template_jinja": None},
ds_cfg={
"chat_template": "jinja",
"chat_template_jinja": "ds_jinja_template",
},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "ds_jinja_template"
# ds_cfg provided with no chat_template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "global_jinja_template"
if __name__ == "__main__":
unittest.main()

View File

@@ -11,7 +11,7 @@ from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
roles={
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
phi35_tokenizer,
chat_template=chat_templates("phi_35"),
chat_template=get_chat_template("phi_35"),
message_field_role="role",
message_field_content="content",
roles={
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
@@ -230,7 +230,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -283,7 +283,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -336,7 +336,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,

View File

@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
drop_system_message=True,
),
tokenizer=llama3_tokenizer,
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
roles=custom_roles,
),
tokenizer=llama3_tokenizer,
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),

View File

@@ -86,6 +86,20 @@ def fixture_llama3_tokenizer():
return tokenizer
@pytest.fixture(name="phi3_tokenizer")
def fixture_phi3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
return tokenizer
@pytest.fixture(name="gemma_tokenizer")
def fixture_gemma_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
return tokenizer
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -99,7 +113,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"type": "chat_template",
}
],
}
@@ -124,7 +138,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"type": "chat_template",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
@@ -152,5 +166,65 @@ class TestAssistantDPOChatTemplateLlama3:
assert result["rejected"] == "party on<|eot_id|>"
class TestAssistantDPOChatTemplatePhi3:
"""
Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
"""
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == (
"<|user|>\nhello<|end|>\n"
+ "<|assistant|>\nhello<|end|>\n"
+ "<|user|>\ngoodbye<|end|>\n"
+ "<|assistant|>\n"
)
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"
class TestAssistantDPOChatTemplateGemma:
"""
Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
"""
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
assert result["prompt"] == (
"<bos><start_of_turn>user\nhello<end_of_turn>\n"
+ "<start_of_turn>model\nhello<end_of_turn>\n"
+ "<start_of_turn>user\ngoodbye<end_of_turn>\n"
+ "<start_of_turn>model\n"
)
assert result["chosen"] == "goodbye<end_of_turn>"
assert result["rejected"] == "party on<end_of_turn>"
if __name__ == "__main__":
unittest.main()

View File

@@ -13,6 +13,7 @@ from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1432,3 +1433,58 @@ class TestValidationComet(BaseValidation):
for key in comet_env.keys():
os.environ.pop(key, None)
class TestValidationMLflow(BaseValidation):
"""
Validation test for MLflow
"""
def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
cfg = (
DictDefault(
{
"hf_mlflow_log_artifacts": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.hf_mlflow_log_artifacts is True
# Check it's not already present in env
assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ
setup_mlflow_env_vars(new_cfg)
assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"
os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)
def test_mlflow_not_used_by_default(self, minimal_cfg):
cfg = DictDefault({}) | minimal_cfg
new_cfg = validate_config(cfg)
setup_mlflow_env_vars(new_cfg)
assert cfg.use_mlflow is not True
cfg = (
DictDefault(
{
"mlflow_experiment_name": "foo",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
setup_mlflow_env_vars(new_cfg)
assert new_cfg.use_mlflow is True
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)

View File

@@ -0,0 +1,238 @@
"""Module for testing the validation module for the dataset config"""
import warnings
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
from axolotl.utils.dict import DictDefault
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
# pylint: disable=too-many-public-methods (duplicate-code)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
class TestValidationCheckDatasetConfig(BaseValidation):
"""
Test the validation for the dataset config to ensure no correct parameters are dropped
"""
def test_dataset_config_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "sharegpt",
"conversation": "chatml",
"shards": 10,
}
]
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template is None
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == ChatTemplate.chatml
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"chat_template": "gemma",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == cfg.chat_template
assert (
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()

View File

@@ -1,18 +1,64 @@
"""Module for testing models utils file."""
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model
from axolotl.utils.models import ModelLoader, load_model
class ModelsUtilsTest(unittest.TestCase):
class TestModelsUtils:
"""Testing module for models utils."""
def setup_method(self) -> None:
# load config
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
{
"base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"load_in_8bit": True,
"load_in_4bit": False,
"adapter": "lora",
"flash_attention": False,
"sample_packing": True,
"device_map": "auto",
}
)
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
spec=PreTrainedTokenizerBase
)
self.inference = False # pylint: disable=attribute-defined-outside-init
self.reference_model = True # pylint: disable=attribute-defined-outside-init
# init ModelLoader
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer=self.tokenizer,
inference=self.inference,
reference_model=self.reference_model,
)
)
def test_set_device_map_config(self):
# check device_map
device_map = self.cfg.device_map
if is_torch_mps_available():
device_map = "mps"
self.model_loader.set_device_map_config()
if is_deepspeed_zero3_enabled():
assert "device_map" not in self.model_loader.model_kwargs
else:
assert device_map in self.model_loader.model_kwargs["device_map"]
# check torch_dtype
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
cfg = DictDefault(
{
@@ -35,3 +81,38 @@ class ModelsUtilsTest(unittest.TestCase):
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
)
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
@pytest.mark.parametrize("load_in_8bit", [True, False])
@pytest.mark.parametrize("load_in_4bit", [True, False])
@pytest.mark.parametrize("gptq", [True, False])
def test_set_quantization_config(
self,
adapter,
load_in_8bit,
load_in_4bit,
gptq,
):
# init cfg as args
self.cfg.load_in_8bit = load_in_8bit
self.cfg.load_in_4bit = load_in_4bit
self.cfg.gptq = gptq
self.cfg.adapter = adapter
self.model_loader.set_quantization_config()
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
assert not (
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
)
elif load_in_8bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_8bit"]
elif load_in_4bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_4bit"]
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
self.cfg.adapter == "lora" and load_in_8bit
):
assert self.model_loader.model_kwargs.get(
"quantization_config", BitsAndBytesConfig
)