Compare commits

46 Commits

Author SHA1 Message Date
NanoCode012  28e7e444ee  fix: update bradleyterry to use new chat_template  2024-10-16 20:42:14 +07:00
NanoCode012  207e7627f9  fix(doc): formatting  2024-10-15 00:41:50 +07:00
NanoCode012  7eb62ae5a9  fix: update dummy message to prevent potential overlap with real content  2024-10-14 23:50:35 +07:00
NanoCode012  95805cf850  chore: lint  2024-10-14 23:43:30 +07:00
NanoCode012  4aafb7e600  fix: imported name incorrectly updated on merge  2024-10-14 23:41:17 +07:00
NanoCode012  17bc4c8b36  fix: update test based on new defaults  2024-10-14 18:03:35 +07:00
NanoCode012  d101cfc125  feat: handles chat_template requiring specific user/assistant order  2024-10-14 14:00:55 +07:00
NanoCode012  e5cd55cff9  feat: add example using fallback  2024-10-14 12:22:22 +07:00
NanoCode012  24aa6b15a0  feat: handle sharegpt deprecation better in docs  2024-10-14 12:21:58 +07:00
NanoCode012  9dfc5fa8b8  fix: remove default setting on edge case where chat template overriden in dataset section  2024-10-14 11:48:40 +07:00
NanoCode012  0c3255288f  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-10-14 10:36:08 +07:00
Chirag Jain  82b5dc9328  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-10-13 16:27:10 +05:30
Chirag Jain  ec57918fcd  Merge pull request #7 from NanoCode012/cj_tokenizer_default_prompt_template (Feat: merge latest, update docs, fix dropped config bug, added unit test)  2024-10-11 14:44:25 +05:30
NanoCode012  dd87d8c438  feat: add test for levy's dpo case  2024-10-11 12:56:46 +07:00
NanoCode012  ef942b6efc  fix: rename var after merge  2024-10-11 12:30:43 +07:00
NanoCode012  3c6a6c61be  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-10-11 12:29:34 +07:00
NanoCode012  7b4b665e99  chore: skip duplicate  2024-10-11 11:42:36 +07:00
NanoCode012  21326e4ef3  chore: lint  2024-10-11 11:40:42 +07:00
NanoCode012  de23dab4fc  fix: config being dropped and unittest to catch that  2024-10-11 11:40:32 +07:00
NanoCode012  e3efa29cf5  fix: test  2024-10-11 11:11:19 +07:00
NanoCode012  2038255052  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-10-10 20:25:37 +07:00
NanoCode012  dab2590e4d  chore: refactor  2024-10-10 18:07:00 +07:00
NanoCode012  e5162b7a41  chore: added example for non-default template  2024-10-10 18:04:33 +07:00
NanoCode012  b6321d2220  chore: clarify doc  2024-10-10 18:01:33 +07:00
NanoCode012  6b3cdfdb8e  feat(doc): updated config with chat template options and clarified examples  2024-10-10 17:57:11 +07:00
NanoCode012  203ae28704  fix: refactor artifact left from main merge  2024-10-10 17:16:41 +07:00
NanoCode012  ed3a33c9fb  fix: re-arrange enum declaration position  2024-10-10 16:18:15 +07:00
NanoCode012  f61e2fc7dc  chore: remove redundant function  2024-10-10 16:15:15 +07:00
NanoCode012  b8056d04d9  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-10-10 16:11:07 +07:00
NanoCode012  88658c0570  fix: set default to tokenizer template  2024-10-10 15:38:19 +07:00
Chirag Jain  260ca97f2c  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-09-13 00:33:49 +05:30
Chirag Jain  b1bb2accb9  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-08-28 13:34:20 +05:30
Chirag Jain  efeaa00bb4  Update docs/dataset-formats/conversation.qmd (Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>)  2024-08-27 19:08:54 +05:30
Chirag Jain  8a84408fc7  Address review comments and add docs  2024-08-27 04:30:35 +05:30
Chirag Jain  4805f3ca0a  Merge branch 'main' of https://github.com/OpenAccess-AI-Collective/axolotl into cj_tokenizer_default_prompt_template  2024-08-27 02:35:58 +05:30
Chirag Jain  8ee30f5954  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-08-23 03:44:25 +05:30
Chirag Jain  6ef76f1ace  remove custom mistral template  2024-08-19 15:56:47 +05:30
Chirag Jain  2e758aed6f  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-08-19 15:52:04 +05:30
Chirag Jain  21a2302538  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-08-12 10:24:02 +05:30
Chirag Jain  89f382a13a  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-08-06 21:23:14 +05:30
Chirag Jain  eb188acbd4  Add option chat_template_jinja to provide a jinja template  2024-07-31 01:43:40 +05:30
Chirag Jain  34ea51dcf3  Fix lint and bug post merge from main  2024-07-30 23:59:38 +05:30
Chirag Jain  fd7538dca7  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-07-30 23:48:43 +05:30
Chirag Jain  99b3bc7fbd  Merge branch 'main' into cj_tokenizer_default_prompt_template  2024-07-23 17:16:49 +05:30
Chirag Jain  4e38cea6b8  Add tests  2024-07-12 09:04:59 +05:30
Chirag Jain  5edaad5b8b  Allow using tokenizer's default chat template with fallbacks  2024-07-12 08:42:26 +05:30

Summary of changes:

1. Adds `tokenizer_default` as an option for `chat_template` in the
   `chat_template` prompt strategy, allowing the chat template from the
   tokenizer's config.json to be used
2. Allows falling back to the chat templates available in axolotl if the
   tokenizer does not have a chat template
3. Adds a Mistral chat template which supports system messages, taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why?

Many popular models are not trained with the chatml format. As a result,
for the model to correctly learn chatml we have to turn on train_on_inputs,
which requires more compute and time. If we can use the chat template the
model has already learned, we only need to learn the output tokens.

---

Todo:

- Write tests
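As a usage sketch of the option this commit describes, the config below asks axolotl to use the tokenizer's own template. The exact key spellings (in particular the `tokenizer_default_fallback_*` form) are assumptions inferred from this description and the later docs commits, not text quoted from the PR:

```yaml
# Hedged sketch of the tokenizer_default chat template option.
base_model: mistralai/Mistral-7B-Instruct-v0.2  # any model shipping its own template

# use the chat template from the tokenizer's config
chat_template: tokenizer_default
# assumed fallback spelling if the tokenizer ships no template:
# chat_template: tokenizer_default_fallback_chatml

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

# the motivation above: with the model's own template, train_on_inputs can
# stay off and only the assistant tokens are learned
train_on_inputs: false
```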
41 changed files with 865 additions and 1965 deletions

View File

@@ -36,12 +36,6 @@ jobs:
         python_version: "3.11"
         pytorch: 2.4.1
         torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-      - cuda: "124"
-        cuda_version: 12.4.1
-        cudnn_version: ""
-        python_version: "3.11"
-        pytorch: 2.5.0
-        torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
         uses: actions/checkout@v3

View File

@@ -29,11 +29,6 @@ jobs:
         python_version: "3.11"
         pytorch: 2.4.1
         axolotl_extras:
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
-        axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -91,11 +86,6 @@
         python_version: "3.11"
         pytorch: 2.4.1
         axolotl_extras:
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
-        axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -21,17 +21,10 @@ jobs:
         pytorch: 2.3.1
         axolotl_extras:
         num_gpus: 2
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.4.1
-        axolotl_extras:
-        num_gpus: 2
-        nightly_build: "true"
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
+      - cuda: 121
+        cuda_version: 12.1.1
+        python_version: "3.11"
+        pytorch: 2.3.1
         axolotl_extras:
         num_gpus: 2
         nightly_build: "true"

View File

@@ -28,11 +28,6 @@ jobs:
         python_version: "3.11"
         pytorch: 2.4.1
         axolotl_extras:
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
-        axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -90,11 +85,6 @@
         python_version: "3.11"
         pytorch: 2.4.1
         axolotl_extras:
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
-        axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
       run: |
         pip3 install wheel packaging
         pip3 install -e .
-        pip3 install -r requirements-dev.txt -r requirements-tests.txt
+        pip3 install -r requirements-tests.txt
     - name: Extract tag name
       id: tag

View File

@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+        pytorch_version: ["2.3.1", "2.4.1"]
     timeout-minutes: 20
     steps:
@@ -47,14 +47,13 @@
         sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
         sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
         sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
-        sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
     - name: Install dependencies
       run: |
         pip3 install --upgrade pip
         pip3 install --upgrade packaging
         pip3 install -U -e .
-        pip3 install -r requirements-dev.txt -r requirements-tests.txt
+        pip3 install -r requirements-tests.txt
     - name: Run tests
       run: |
@@ -96,13 +95,6 @@
         num_gpus: 1
         axolotl_extras:
         nightly_build: "true"
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
-        num_gpus: 1
-        axolotl_extras:
-        nightly_build: "true"
     steps:
       - name: Checkout
         uses: actions/checkout@v4

View File

@@ -36,7 +36,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+        pytorch_version: ["2.3.1", "2.4.1"]
     timeout-minutes: 20
     steps:
@@ -49,20 +49,16 @@
           python-version: ${{ matrix.python_version }}
           cache: 'pip' # caching pip dependencies
-      - name: upgrade pip
-        run: |
-          pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
       - name: Install PyTorch
         run: |
-          pip3 install torch==${{ matrix.pytorch_version }}
+          pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
       - name: Install dependencies
         run: |
-          pip3 show torch
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging
           pip3 install -U -e .
-          pip3 install -r requirements-dev.txt -r requirements-tests.txt
+          pip3 install -r requirements-tests.txt
       - name: Run tests
         run: |
@@ -76,7 +72,7 @@
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
-    timeout-minutes: 90
+    timeout-minutes: 60
     needs: [pre-commit, pytest]
     strategy:
@@ -101,12 +97,6 @@
         pytorch: 2.4.1
         num_gpus: 1
         axolotl_extras:
-      - cuda: 124
-        cuda_version: 12.4.1
-        python_version: "3.11"
-        pytorch: 2.5.0
-        num_gpus: 1
-        axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4

View File: 1991.yml (295 changed lines)

@@ -1,295 +0,0 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
  eos_token: <|im_end|>

View File

@@ -121,7 +121,7 @@ Features:
 Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.

-**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
+**Requirements**: Python >=3.10 and Pytorch >=2.1.1.

 ```bash
 git clone https://github.com/axolotl-ai-cloud/axolotl

View File

@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
fi fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi fi
# So we can test the Docker image # So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
 #!/bin/bash
 set -e

-pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
+pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
 pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
 pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
 @stub.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=60 * 60,
+    timeout=45 * 60,
     cpu=8.0,
     memory=131072 * N_GPUS,
 )

View File

@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
 @stub.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=60 * 60,
+    timeout=45 * 60,
     cpu=8.0,
     memory=131072,
 )

View File

@@ -14,6 +14,15 @@
     "bf16": {
         "enabled": true
     },
+    "fp16": {
+        "enabled": "auto",
+        "auto_cast": false,
+        "loss_scale": 0,
+        "initial_scale_power": 32,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
     "train_batch_size": "auto",

View File

@@ -24,6 +24,15 @@
     "bf16": {
         "enabled": true
     },
+    "fp16": {
+        "enabled": "auto",
+        "auto_cast": false,
+        "loss_scale": 0,
+        "initial_scale_power": 32,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
     "train_batch_size": "auto",

View File

@@ -20,6 +20,15 @@
     "bf16": {
         "enabled": true
     },
+    "fp16": {
+        "enabled": "auto",
+        "auto_cast": false,
+        "loss_scale": 0,
+        "initial_scale_power": 32,
+        "loss_scale_window": 1000,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
     "train_batch_size": "auto",

View File

@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
 WORKDIR /workspace/axolotl

 # If AXOLOTL_EXTRAS is set, append it in brackets
+RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \

View File

@@ -11,6 +11,7 @@ rl: dpo
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_dpo_test
     type: chat_template.default
+    chat_template: llama3
     field_messages: conversation
     field_chosen: chosen
     field_rejected: rejected

View File

@@ -10,6 +10,7 @@ chat_template: llama3
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
+    chat_template: llama3
     field_messages: messages
     message_field_role: role
     message_field_content: content
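The two example diffs above set the template inside the dataset entry as well as at the top level. A sketch of how the two compose, assuming, per commit 9dfc5fa8b8's message, that a `chat_template` inside a dataset entry overrides the top-level default:

```yaml
chat_template: chatml        # top-level default
datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
    chat_template: llama3    # assumed per-dataset override for this source only
    field_messages: messages
```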

View File

@@ -1,77 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"

View File

@@ -2,4 +2,3 @@ pre-commit
 black
 mypy
 types-requests
-tbparse

View File

@@ -1,12 +1,12 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.45.2
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
 accelerate==1.0.1
 datasets==3.0.1
-deepspeed==0.15.3
+deepspeed==0.14.4
 pydantic==2.6.3
 addict
 fire
@@ -16,7 +16,7 @@ flash-attn==2.6.3
 sentencepiece
 wandb
 einops
-xformers>=0.0.23.post1
+xformers==0.0.28.post1
 optimum==1.16.2
 hf_transfer
 colorama
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs
-trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
+trl==0.9.6
 zstandard==0.22.0
 fastcore

View File

@@ -31,8 +31,6 @@ def parse_requirements():
     try:
         xformers_version = [req for req in _install_requires if "xformers" in req][0]
         torchao_version = [req for req in _install_requires if "torchao" in req][0]
-        autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
         if "Darwin" in platform.system():
             # don't install xformers on MacOS
             _install_requires.pop(_install_requires.index(xformers_version))
@@ -52,16 +50,10 @@ def parse_requirements():
             else:
                 raise ValueError("Invalid version format")

-            if (major, minor) >= (2, 5):
-                _install_requires.pop(_install_requires.index(xformers_version))
-                _install_requires.pop(_install_requires.index(autoawq_version))
-            elif (major, minor) >= (2, 4):
+            if (major, minor) >= (2, 4):
                 if patch == 0:
                     _install_requires.pop(_install_requires.index(xformers_version))
                     _install_requires.append("xformers>=0.0.27")
-                else:
-                    _install_requires.pop(_install_requires.index(xformers_version))
-                    _install_requires.append("xformers==0.0.28.post1")
             elif (major, minor) >= (2, 3):
                 _install_requires.pop(_install_requires.index(torchao_version))
                 if patch == 0:
@@ -81,6 +73,7 @@ def parse_requirements():
     except PackageNotFoundError:
         pass
+
     return _install_requires, _dependency_links
@@ -109,7 +102,6 @@ setup(
         ],
         "mamba-ssm": [
             "mamba-ssm==1.2.0.post1",
-            "causal_conv1d",
         ],
         "auto-gptq": [
             "auto-gptq==0.5.1",

View File

@@ -462,12 +462,7 @@ def load_datasets(
         processor=processor,
     )

-    if (
-        cli_args.debug
-        or cfg.debug
-        or cli_args.debug_text_only
-        or int(cli_args.debug_num_examples) > 0
-    ):
+    if cli_args.debug or cfg.debug:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(
             train_dataset.select(

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
-    debug_num_examples: int = field(default=0)
+    debug_num_examples: int = field(default=5)
     inference: bool = field(default=False)
     merge_lora: bool = field(default=False)
     prompter: Optional[str] = field(default=None)

View File

@@ -7,7 +7,6 @@ import abc
 import gc
 import importlib
 import importlib.util
-import inspect
 import logging
 import math
 import os
@@ -28,6 +27,7 @@ from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
     EarlyStoppingCallback,
+    PreTrainedModel,
     Trainer,
     TrainerCallback,
     TrainingArguments,
@@ -666,9 +666,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         return DataLoader(bench_dataset, **dataloader_params)
         # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))

-    def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
-    ):
+    def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
@@ -676,18 +674,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         #     return (loss, outputs) if return_outputs else loss
         if self.args.orpo_alpha:
-            return self.orpo_compute_loss(
-                model,
-                inputs,
-                return_outputs=return_outputs,
-                num_items_in_batch=num_items_in_batch,
-            )
-        return super().compute_loss(
-            model,
-            inputs,
-            return_outputs=return_outputs,
-            num_items_in_batch=num_items_in_batch,
-        )
+            return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)

     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -783,13 +771,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         ).squeeze(2)
         return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)

-    def orpo_compute_loss(
-        self,
-        model,
-        inputs,
-        return_outputs=False,
-        num_items_in_batch=None,  # pylint: disable=unused-argument
-    ):
+    def orpo_compute_loss(self, model, inputs, return_outputs=False):
         concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
             inputs,
             label_pad_token=-100,
@@ -895,13 +877,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)

-    def _save_checkpoint(self, model, trial):
+    def _save_checkpoint(self, model, trial, metrics=None):
         # make sure the checkpoint dir exists, since trainer is flakey
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial)
+        return super()._save_checkpoint(model, trial, metrics=metrics)

 class AxolotlMambaTrainer(AxolotlTrainer):
@@ -916,7 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
         model,
         inputs,
         return_outputs=False,  # pylint: disable=unused-argument
-        num_items_in_batch=None,  # pylint: disable=unused-argument
     ):
         input_ids = inputs.pop("input_ids")
         lm_logits = model(input_ids).logits
@@ -1024,32 +1005,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         return super().push_to_hub(*args, **kwargs)

     def tokenize_row(
-        self,
-        features,
-        processing_class,
-        max_prompt_length,
-        max_completion_length,
-        add_special_tokens,
+        self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
     ) -> Dict:
-        res = super().tokenize_row(
-            features,
-            processing_class,
-            max_prompt_length,
-            max_completion_length,
-            add_special_tokens,
-        )
-        if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
+        res = super().tokenize_row(feature, model=model)
+        if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
             for key in res.keys():
                 res[key] = res[key][1:]
         return res

     def training_step(
-        self,
-        model: nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        num_items_in_batch=None,
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
     ) -> torch.Tensor:
-        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
+        loss: torch.Tensor = super().training_step(model, inputs)
         gc.collect()
         torch.cuda.empty_cache()
         return loss
@@ -1152,17 +1119,12 @@ class TrainerBuilderBase(abc.ABC):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
         if self.cfg.use_mlflow and is_mlflow_available():
-            from transformers.integrations.integration_utils import MLflowCallback
-
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
             )

-            callbacks.extend(
-                [
-                    SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
-                    MLflowCallback,
-                ]
-            )
+            callbacks.append(
+                SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
+            )
         if self.cfg.use_comet and is_comet_available():
             from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -1700,17 +1662,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return_tensors="pt",
             **data_collator_kwargs,
         )

-        sig = inspect.signature(trainer_cls)
-        if "processing_class" in sig.parameters.keys():
-            trainer_kwargs["processing_class"] = self.tokenizer
-        else:
-            trainer_kwargs["tokenizer"] = self.tokenizer
-
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
+            tokenizer=self.tokenizer,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
             callbacks=self.get_callbacks(),
             **trainer_kwargs,
@@ -1751,8 +1708,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ]
         if self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
-            if "max_length" in kwargs:
-                kwargs.pop("max_length")
         elif use_batch_sampler_collator:
             if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
@@ -1955,7 +1910,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
             dpo_trainer_kwargs["max_target_length"] = None
             dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
+            dpo_trainer_kwargs["generate_during_eval"] = True
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
@@ -1967,17 +1922,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")

-        sig = inspect.signature(trainer_cls)
-        if "processing_class" in sig.parameters.keys():
-            dpo_trainer_kwargs["processing_class"] = self.tokenizer
-        else:
-            dpo_trainer_kwargs["tokenizer"] = self.tokenizer
-
         dpo_trainer = trainer_cls(
             *trainer_cls_args,
             args=training_args,
             train_dataset=self.train_dataset,
+            tokenizer=self.tokenizer,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )

View File

@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
+from xformers.ops import SwiGLU

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -43,19 +44,7 @@ except ImportError:
 LOG = logging.getLogger("axolotl")

-def is_xformers_available() -> bool:
-    try:
-        import xformers  # pylint: disable=unused-import  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
 def is_xformers_swiglu_available() -> bool:
-    if not is_xformers_available():
-        return False
-
     from xformers.ops.common import get_xformers_operator

     try:
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:

 def replace_llama_mlp_with_swiglu(model):
-    if is_xformers_swiglu_available():
-        from axolotl.monkeypatch.xformers_ import FusedMLP
-    else:
-        raise RuntimeError("xformers SwiGLU not available for this environment")
     for name, module in model.named_modules():
         if isinstance(module, LlamaMLP):
             mlp = FusedMLP(
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
         set_module_name(model, name, new_attn)

+class FusedMLP(torch.nn.Module):
+    """
+    Fused MLP layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        gate_proj: torch.nn.Linear,
+        up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = LlamaMLP(self.config)
+        new_mlp.gate_proj.weight.data = w1
+        new_mlp.up_proj.weight.data = w2
+        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
+
+
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask

 def _prepare_decoder_attention_mask(

View File

@@ -27,18 +27,15 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]

-# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
-def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
+def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
     if model_type == "gemmoe":
         patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
     elif model_type == "deepseek_v2":
         patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils"):
-        if not has_remote_code:
-            transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
+    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
+        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
     if model_type == "mixtral" and is_deepspeed_zero3_enabled():
         patch_mixtral_moe_forward_zero3()
         return

View File

@@ -16,6 +16,26 @@
 LOG = get_logger("axolotl.monkeypatch.unsloth")

+ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+"""
+
+PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        loss = fast_cross_entropy_loss(
+            logits = shift_logits,
+            labels = shift_labels,
+        )
+"""
+
 ORIGINAL_QKV_CODE = """
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
@@ -60,6 +80,12 @@ def get_forward_code() -> str:
     return forward

+def check_cel_is_patchable() -> bool:
+    forward = get_forward_code()
+    forward, _ = detab_code(forward)
+    return ORIGINAL_CEL_CODE in forward
+
+
 def get_self_attn_code() -> str:
     forward = inspect.getsource(LlamaFlashAttention2.forward)
     return forward
@@ -72,31 +98,48 @@ def check_self_attn_is_patchable() -> bool:

 def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
-    from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
-
-    def UnslothForCausalLMLoss(  # pylint: disable=invalid-name
-        logits,
-        labels,
-        vocab_size: int,  # pylint: disable=unused-argument
-        num_items_in_batch: int = None,
-        ignore_index: int = -100,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        # Upcast to float if we need to compute the loss to avoid potential precision issues
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        loss = fast_cross_entropy_loss(
-            logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
-        )
-        return loss
-
     if model_type == "llama":
-        from transformers.loss import loss_utils
-
-        loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss  # type: ignore[assignment]
+        forward = get_forward_code()
+        LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
+        forward, _ = detab_code(forward)
+        assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+
+        forward = forward.replace(
+            "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
+        )
+        forward = forward.replace(
+            "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
+            "",
+        )
+        forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
+        forward = forward.replace(
+            "def forward(",
+            "def fast_cross_entropy_loss_forward(",
+            1,
+        )
+
+        # load imports necessary
+        import transformers.models.llama.modeling_llama
+
+        items_to_import = []
+        for item in dir(transformers.models.llama.modeling_llama):
+            if item in forward:
+                items_to_import.append(item)
+
+        exec(  # pylint: disable=exec-used  # nosec B102
+            "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
+            globals(),
+        )
+        exec(  # pylint: disable=exec-used  # nosec B102
+            "from transformers.models.llama.modeling_llama import ("
+            + ", ".join(x for x in items_to_import)
+            + ")",
+            globals(),
+        )
+        exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
+        LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
+        LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
     else:
         raise ValueError("Unsupported model type")

View File

@@ -1,51 +0,0 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""

import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU

from axolotl.monkeypatch.utils import set_module_name


class FusedMLP(torch.nn.Module):
    """
    Fused MLP layer for incrementally improved training efficiency
    """

    def __init__(
        self,
        config,
        gate_proj: torch.nn.Linear,
        up_proj: torch.nn.Linear,
        down_proj: torch.nn.Linear,
    ):
        super().__init__()
        self.config = config
        self.swiglu = SwiGLU(
            in_features=config.hidden_size,
            hidden_features=config.intermediate_size,
            bias=False,
            _pack_weights=True,
        )
        # overwrite initialized weights with pretrained weights
        self.swiglu.w12.weight.data = torch.cat(
            (gate_proj.weight.data, up_proj.weight.data), dim=0
        )
        self.swiglu.w3.weight.data = down_proj.weight.data

    def _post_training(self, model, name):
        w1, w2 = torch.split(  # pylint: disable=invalid-name
            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
        )

        # Assign the split weights back to the original layers
        new_mlp = LlamaMLP(self.config)
        new_mlp.gate_proj.weight.data = w1
        new_mlp.up_proj.weight.data = w2
        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data

        set_module_name(model, name, new_mlp)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
        return self.swiglu(x)

View File

@@ -260,10 +260,8 @@ def train(
     if not cfg.hub_model_id:
         try:
-            trainer.create_model_card(
-                model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
-            )
-        except (AttributeError, UnicodeDecodeError):
+            trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
+        except AttributeError:
             pass
     elif cfg.hub_model_id:
         # defensively push to the hub to ensure the model card is updated

View File

@@ -583,7 +583,6 @@ class AxolotlInputConfig(
    resume_from_checkpoint: Optional[str] = None
    auto_resume_from_checkpoints: Optional[bool] = None
    resize_token_embeddings_to_32x: Optional[bool] = None
-   mean_resizing_embeddings: Optional[bool] = False

    rl: Optional[RLType] = None
    reward_model: Optional[bool] = None

View File

@@ -16,7 +16,3 @@ def setup_mlflow_env_vars(cfg: DictDefault):
     # Enable mlflow if experiment name is present
     if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
         cfg.use_mlflow = True
-
-    # Enable logging hf artifacts in mlflow if value is truthy
-    if cfg.hf_mlflow_log_artifacts is True:
-        os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"

File diff suppressed because it is too large.

View File

@@ -133,8 +133,6 @@ class MultipackBatchSampler(BatchSampler):
         self.eff_total_used = 0
         self.eff_total_slots = 0

-        self.len_across_ranks = None
-
     def set_epoch(self, epoch: int):
         self.epoch = epoch
@@ -197,14 +195,15 @@ class MultipackBatchSampler(BatchSampler):
             LOG.info(f"gather_len_batches: {repr(estimates)}")
             return math.floor(0.998 * min(estimates))

-        min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
+        min_len_batches = reduce_and_broadcast(
+            lambda: num,
+            calc_min_len,
+        )
         return min_len_batches

     def __len__(self):
-        if not self.len_across_ranks:
-            len_batches = self.num_batches()
-            self.len_across_ranks = self.gather_len_batches(len_batches)
-        return self.len_across_ranks
+        len_batches = self.num_batches()
+        return self.gather_len_batches(len_batches)

     def _len_est(self):
         efficiency = (

View File

@@ -1,155 +0,0 @@
"""
E2E tests for multigpu eval
"""
import logging
import os
import unittest
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPUEval(unittest.TestCase):
"""
Test case for MultiGPU Eval Sample Packing
"""
@with_temp_dir
def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
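Note: both removed tests launch training out of process; the equivalent standalone invocation, distilled from the argument list above (the config path is illustrative):

import subprocess

# same arguments the tests pass to execute_subprocess_async
subprocess.run(
    [
        "accelerate", "launch",
        "--num-processes", "2",
        "-m", "axolotl.cli.train",
        "path/to/config.yaml",
    ],
    check=True,
)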

@@ -1,12 +1,22 @@
 """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
 import unittest
 
-from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
+from axolotl.monkeypatch.unsloth_ import (
+    check_cel_is_patchable,
+    check_self_attn_is_patchable,
+)
 
 
 class TestUnslothIntegration(unittest.TestCase):
     """Unsloth monkeypatch integration tests."""
 
+    def test_is_cel_patchable(self):
+        # ensures the current version of transformers has loss code that matches our patching code
+        self.assertTrue(
+            check_cel_is_patchable(),
+            "HF transformers loss code has changed and isn't patchable",
+        )
+
     def test_is_self_attn_patchable(self):
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
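Note: the new `check_cel_is_patchable` guard follows the same pattern as the existing self-attention check: a source-level monkeypatch is only applied while the upstream code still looks the way the patch expects. A rough sketch of that pattern (the snippet and function are placeholders, not axolotl's actual implementation):

import inspect

# the code fragment the monkeypatch expects to find and rewrite (illustrative)
ORIGINAL_SNIPPET = "loss = loss_fct(shift_logits, shift_labels)"

def check_is_patchable(target_fn) -> bool:
    # if upstream refactors the function, the snippet disappears and the
    # patch must be skipped rather than silently corrupting the forward pass
    return ORIGINAL_SNIPPET in inspect.getsource(target_fn)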

@@ -1,95 +0,0 @@
"""Module for testing ModelLoader."""
import shutil
import tempfile
import pytest
import torch
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
def fixture_temp_dir():
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
class TestLoadModelUtils:
"""
Testing module testing ModelLoader.
"""
def setup_method(self):
# load config
self.cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"tokenizer_config": "JackFram/llama-68m",
"sequence_len": 1024,
"load_in_8bit": False,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer="",
)
)
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
@pytest.mark.parametrize(
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
)
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
def test_convert_embedding_modules_dtype(
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():
if (
"norm" in name
or (before_kbit_train_or_finetune and name.endswith(".gate"))
or (
any(m in name for m in embedding_modules)
and hasattr(module, "weight")
)
):
for _, param in module.named_parameters():
assert param.dtype == dist_dtype
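Note: for context on what this deleted test covered, `convert_embedding_modules_dtype` casts embedding and norm weights to a target dtype around k-bit training. A rough sketch of the behaviour the assertions above imply (not the real method):

import torch
from torch import nn

def convert_embedding_modules_dtype_sketch(
    model: nn.Module, embedding_modules: list, dist_dtype: torch.dtype
) -> None:
    # cast norm layers plus any module whose name matches an embedding
    # module and that carries a weight, mirroring the test's assertions
    for name, module in model.named_modules():
        if "norm" in name or (
            any(m in name for m in embedding_modules) and hasattr(module, "weight")
        ):
            module.to(dist_dtype)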

@@ -1,74 +0,0 @@
"""
E2E tests for packed training
"""
import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
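Note: the loss assertion pattern in this removed test generalizes to any logged scalar; a small helper distilled from the same `tbparse` calls used above:

from tbparse import SummaryReader

def last_scalar(event_file: str, tag: str) -> float:
    # read all scalars from a TensorBoard event file and return the most
    # recent value logged under `tag`, e.g. "train/train_loss"
    df = SummaryReader(event_file).scalars
    return df[df.tag == tag].value.values[-1]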

@@ -13,7 +13,6 @@ from axolotl.utils import is_comet_available
 from axolotl.utils.config import validate_config
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import check_model_config
 from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1433,58 +1432,3 @@ class TestValidationComet(BaseValidation):
         for key in comet_env.keys():
             os.environ.pop(key, None)
-
-
-class TestValidationMLflow(BaseValidation):
-    """
-    Validation test for MLflow
-    """
-
-    def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
-        cfg = (
-            DictDefault(
-                {
-                    "hf_mlflow_log_artifacts": True,
-                }
-            )
-            | minimal_cfg
-        )
-
-        new_cfg = validate_config(cfg)
-
-        assert new_cfg.hf_mlflow_log_artifacts is True
-
-        # Check it's not already present in env
-        assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ
-
-        setup_mlflow_env_vars(new_cfg)
-
-        assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"
-
-        os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)
-
-    def test_mlflow_not_used_by_default(self, minimal_cfg):
-        cfg = DictDefault({}) | minimal_cfg
-
-        new_cfg = validate_config(cfg)
-
-        setup_mlflow_env_vars(new_cfg)
-
-        assert cfg.use_mlflow is not True
-
-        cfg = (
-            DictDefault(
-                {
-                    "mlflow_experiment_name": "foo",
-                }
-            )
-            | minimal_cfg
-        )
-
-        new_cfg = validate_config(cfg)
-
-        setup_mlflow_env_vars(new_cfg)
-
-        assert new_cfg.use_mlflow is True
-
-        os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)

@@ -1,64 +1,18 @@
 """Module for testing models utils file."""
 
-from unittest.mock import MagicMock, patch
+import unittest
+from unittest.mock import patch
 
 import pytest
-from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
-from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.utils.import_utils import is_torch_mps_available
 
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import ModelLoader, load_model
+from axolotl.utils.models import load_model
 
 
-class TestModelsUtils:
+class ModelsUtilsTest(unittest.TestCase):
     """Testing module for models utils."""
 
-    def setup_method(self) -> None:
-        # load config
-        self.cfg = DictDefault(  # pylint: disable=attribute-defined-outside-init
-            {
-                "base_model": "JackFram/llama-68m",
-                "model_type": "LlamaForCausalLM",
-                "tokenizer_type": "LlamaTokenizer",
-                "load_in_8bit": True,
-                "load_in_4bit": False,
-                "adapter": "lora",
-                "flash_attention": False,
-                "sample_packing": True,
-                "device_map": "auto",
-            }
-        )
-        self.tokenizer = MagicMock(  # pylint: disable=attribute-defined-outside-init
-            spec=PreTrainedTokenizerBase
-        )
-        self.inference = False  # pylint: disable=attribute-defined-outside-init
-        self.reference_model = True  # pylint: disable=attribute-defined-outside-init
-
-        # init ModelLoader
-        self.model_loader = (  # pylint: disable=attribute-defined-outside-init
-            ModelLoader(
-                cfg=self.cfg,
-                tokenizer=self.tokenizer,
-                inference=self.inference,
-                reference_model=self.reference_model,
-            )
-        )
-
-    def test_set_device_map_config(self):
-        # check device_map
-        device_map = self.cfg.device_map
-        if is_torch_mps_available():
-            device_map = "mps"
-        self.model_loader.set_device_map_config()
-        if is_deepspeed_zero3_enabled():
-            assert "device_map" not in self.model_loader.model_kwargs
-        else:
-            assert device_map in self.model_loader.model_kwargs["device_map"]
-
-        # check torch_dtype
-        assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
-
     def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
         cfg = DictDefault(
             {
@@ -81,38 +35,3 @@ class TestModelsUtils:
             "shifted-sparse attention does not currently support sample packing"
             in str(exc.value)
         )
-
-    @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
-    @pytest.mark.parametrize("load_in_8bit", [True, False])
-    @pytest.mark.parametrize("load_in_4bit", [True, False])
-    @pytest.mark.parametrize("gptq", [True, False])
-    def test_set_quantization_config(
-        self,
-        adapter,
-        load_in_8bit,
-        load_in_4bit,
-        gptq,
-    ):
-        # init cfg as args
-        self.cfg.load_in_8bit = load_in_8bit
-        self.cfg.load_in_4bit = load_in_4bit
-        self.cfg.gptq = gptq
-        self.cfg.adapter = adapter
-
-        self.model_loader.set_quantization_config()
-        if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
-            assert not (
-                hasattr(self.model_loader.model_kwargs, "load_in_8bit")
-                and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
-            )
-        elif load_in_8bit and self.cfg.adapter is not None:
-            assert self.model_loader.model_kwargs["load_in_8bit"]
-        elif load_in_4bit and self.cfg.adapter is not None:
-            assert self.model_loader.model_kwargs["load_in_4bit"]
-
-        if (self.cfg.adapter == "qlora" and load_in_4bit) or (
-            self.cfg.adapter == "lora" and load_in_8bit
-        ):
-            assert self.model_loader.model_kwargs.get(
-                "quantization_config", BitsAndBytesConfig
-            )
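Note: the deleted `test_set_quantization_config` encodes the precedence that matters when loading quantized models; a condensed sketch of that precedence, grounded in the removed assertions rather than the real `set_quantization_config`:

from transformers import BitsAndBytesConfig

def build_quantization_kwargs(adapter, load_in_8bit, load_in_4bit, gptq) -> dict:
    kwargs = {}
    if gptq:
        # GPTQ checkpoints ship their own quantization config; bnb flags
        # must not be set alongside it
        return kwargs
    if adapter == "qlora" and load_in_4bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    elif adapter == "lora" and load_in_8bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    return kwargs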