Compare commits

..

16 Commits

Author SHA1 Message Date
sunny
d1bf20f990 test 2024-11-01 17:15:32 -04:00
sunny
bb648cbc63 test 2024-11-01 17:02:35 -04:00
sunny
8b0bca4842 test 2024-11-01 17:01:03 -04:00
sunny
d36baf44b1 test 2024-11-01 17:00:35 -04:00
sunny
16c8140d20 test 2024-11-01 16:38:09 -04:00
sunny
21c25cf7bc test 2024-11-01 16:34:00 -04:00
sunny
32288a5d3c test 2024-11-01 16:23:01 -04:00
Wing Lian
052a9a79b4 only run the remainder of the gpu test suite if one case passes first (#2009) [skip ci]
* only run the remainder of the gpu test suite if one case passes first

* also reduce the test matrix
2024-10-31 13:45:01 -04:00
Wing Lian
3591bcfaf9 add torch 2.5.1 for base image (#2010) 2024-10-31 13:27:49 -04:00
Wing Lian
dc1de7d81b add retries for load datasets requests failures (#2007) 2024-10-31 13:26:14 -04:00
Chirag Jain
d4dbfa02fe Add plugin manager's callback hooks to training flow (#2006)
* Add plugin manager's callback hooks to training flow

* Use .values() instead of .items()
2024-10-31 12:13:46 -04:00
NanoCode012
5c7e89105d Fix: modelloader handling of model_kwargs load_in*bit (#1999)
* fix: load_in_*bit not properly read

* fix: load_*bit check

* fix: typo

* refactor: load * bit handling

* feat: add test dpo lora multi-gpu

* fix: turn off sample packing for dpo

* fix: missing warmup_steps

* fix: test to load in 8bit for lora

* skip 8bit lora on h100, add 4bit lora on h100 to multi gpu tests

* chore: reduce max_steps

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-30 14:41:34 -04:00
Chirag Jain
74db2a1bae Fix get_chat_template call for trainer builder (#2003) 2024-10-30 14:27:00 -04:00
Geun, Lim
e62554c419 feat: add Exaone3 chat_template (#1995) 2024-10-30 12:30:12 -04:00
Wing Lian
32c60765ef remove skipped test (#2002)
* remove skipped test

* use mean_resizing_embeddings with qlora and added tokens

* use </s> as pad_token to prevent resize of embeddings

* make sure local hub test saves to a tmp dir

* use Path so concatenation works

* make sure to use tmp_ds_path for data files
2024-10-30 12:27:04 -04:00
NanoCode012
8c3a727f9d feat: update yml chat_template to specify dataset field (#2001) [skip ci]
* feat: update yml chat_template to specify dataset field

* feat: replace sharegpt references with chat_template
2024-10-29 10:26:03 -04:00
28 changed files with 425 additions and 580 deletions

View File

@@ -40,7 +40,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -82,13 +82,6 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -72,13 +72,53 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
docker-e2e-tests-1st:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false
matrix:
@@ -89,18 +129,6 @@ jobs:
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

295
1991.yml
View File

@@ -1,295 +0,0 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
eos_token: <|im_end|>

View File

@@ -121,7 +121,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# fastchat conversation (deprecation soon, use chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
@@ -562,7 +562,8 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

View File

@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- path: philschmid/guanaco-sharegpt-style
type: sharegpt
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
type: sharegpt
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
type: chat_template
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
"name": "Debug axolotl prompt - sharegpt",
"name": "Debug axolotl prompt - chat_template",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -16,7 +16,10 @@ chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -11,8 +11,11 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -4,11 +4,15 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
chat_template: jamba
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -4,28 +4,26 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
strict: false
chat_template: llama3
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_project: check_liger_hf_GA_llama_fix-3
wandb_entity: axolotl-ai
wandb_watch:
wandb_name:
wandb_name: pr/fix333-tr4.46.1
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -10,7 +10,6 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.0
transformers==4.46.1
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
@@ -33,8 +33,8 @@ gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.3.0
triton>=3.1.0
liger-kernel==0.3.1
mamba-ssm==1.2.0.post1

View File

@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -48,6 +48,7 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -895,13 +896,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial):
def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial)
return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -1147,6 +1148,12 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1173,11 +1180,17 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1223,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1595,7 +1608,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
self.cfg.chat_template,
tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo":
@@ -1790,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -1999,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = []
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):

View File

@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import List
from typing import OrderedDict
class BasePlugin:
@@ -47,7 +48,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg):
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
@@ -63,7 +64,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg):
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
@@ -74,7 +75,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model):
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -86,7 +87,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model):
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
@@ -98,7 +99,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model):
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
@@ -110,7 +111,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer):
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -122,7 +123,9 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(self, cfg, trainer, optimizer):
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -135,7 +138,7 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model):
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer before training.
@@ -146,8 +149,11 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(self, cfg, trainer):
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after training.
@@ -158,8 +164,9 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model):
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
@@ -171,7 +178,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg):
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
@@ -227,7 +234,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: List[BasePlugin] = []
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
@@ -237,7 +244,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
cls._instance.plugins = collections.OrderedDict()
return cls._instance
@staticmethod
@@ -265,7 +272,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
self.plugins[plugin_name] = plugin
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -277,7 +284,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins:
for plugin in self.plugins.values():
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -293,7 +300,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -307,7 +314,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -321,7 +328,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -335,7 +342,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -349,7 +356,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -367,7 +374,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -385,7 +392,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
@@ -401,7 +408,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
@@ -416,5 +423,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)

View File

@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,9 +15,12 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

View File

@@ -27,18 +27,15 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
elif hasattr(transformers, "modeling_flash_attention_utils"):
if not has_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return

File diff suppressed because one or more lines are too long

View File

@@ -57,6 +57,7 @@ class ChatTemplate(str, Enum):
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):

View File

@@ -2,9 +2,11 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:

View File

@@ -394,15 +394,10 @@ class ModelLoader:
and self.cfg.flash_attention
and self.cfg.sample_packing
):
has_remote_code = (
"auto_map" in self.model_config
and self.model_type in self.model_config["auto_map"]
)
patch_for_multipack(
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
is_remote_code=self.cfg.trust_remote_code,
)
if self.cfg.is_llama_derived_model:
@@ -645,9 +640,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
):
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -670,9 +663,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
):
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
@@ -685,10 +676,8 @@ class ModelLoader:
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
if "load_in_8bit" in self.model_kwargs:
del self.model_kwargs["load_in_8bit"]
if "load_in_4bit" in self.model_kwargs:
del self.model_kwargs["load_in_4bit"]
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
def set_attention_config(self) -> None:
"""
@@ -973,17 +962,10 @@ class ModelLoader:
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (is_load_in_8bit or is_load_in_4bit)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1121,16 +1103,10 @@ class ModelLoader:
# ---------------------------------------------------------
# put model to accelerator
# ---------------------------------------------------------
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
self.cfg.ddp
and not is_load_in_8bit
and not (self.cfg.rl and is_load_in_4bit)
and not self.cfg.load_in_8bit
and not (self.cfg.rl and self.cfg.load_in_4bit)
and not skip_move_to_device
):
# TODO revaldate this conditional

View File

@@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import is_hopper, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -59,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -116,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 50,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -144,6 +144,146 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
@with_temp_dir
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp(self, temp_dir):
# pylint: disable=duplicate-code
@@ -165,7 +305,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -231,7 +371,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -273,7 +413,6 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skip("disabled due to upstream issue")
@with_temp_dir
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
@@ -282,6 +421,7 @@ class TestMultiGPULlama(unittest.TestCase):
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
"tokenizer_type": "AutoTokenizer",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
@@ -297,7 +437,7 @@ class TestMultiGPULlama(unittest.TestCase):
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|end_of_text|>",
"pad_token": "</s>",
},
"datasets": [
{
@@ -307,7 +447,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -373,7 +513,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -432,7 +572,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,

View File

@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"warmup_steps": 20,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_1_1, with_temp_dir
from ..utils import require_torch_2_3_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_1_1
@require_torch_2_3_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -9,6 +9,8 @@ from functools import wraps
from importlib.metadata import version
from pathlib import Path
import torch
def with_temp_dir(test_func):
@wraps(test_func)
@@ -35,13 +37,18 @@ def most_recent_subdir(path):
return subdir
def require_torch_2_1_1(test_case):
def require_torch_2_3_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
Decorator marking a test that requires torch >= 2.3.1
"""
def is_min_2_1_1():
def is_min_2_3_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return torch_version >= "2.3.1"
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)

View File

@@ -367,43 +367,44 @@ class TestDatasetPreparation(unittest.TestCase):
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
with tempfile.TemporaryDirectory() as tmp_dir2:
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
f"{tmp_ds_path}/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":