diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index d215ea44c..f8eaff270 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -37,6 +37,11 @@ jobs:
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+ - cuda: "121"
+ cuda_version: 12.1.0
+ python_version: "3.11"
+ pytorch: 2.3.1
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 8bced628d..4969de75d 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -19,7 +19,6 @@ jobs:
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -33,8 +32,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
- pytorch: 2.3.0
+ pytorch: 2.3.1
axolotl_extras:
+ is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -80,7 +80,6 @@ jobs:
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -94,8 +93,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
- pytorch: 2.3.0
+ pytorch: 2.3.1
axolotl_extras:
+ is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -136,7 +136,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
- pytorch: 2.3.0
+ pytorch: 2.3.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml
index 6dc22b6bf..770954b85 100644
--- a/.github/workflows/nightlies.yml
+++ b/.github/workflows/nightlies.yml
@@ -18,7 +18,6 @@ jobs:
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -32,8 +31,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
- pytorch: 2.3.0
+ pytorch: 2.3.1
axolotl_extras:
+ is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -80,7 +80,6 @@ jobs:
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- is_latest: true
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
@@ -94,8 +93,9 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
- pytorch: 2.3.0
+ pytorch: 2.3.1
axolotl_extras:
+ is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 2e2d0968d..1cee8cbcb 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -57,6 +57,10 @@ jobs:
run: |
pytest --ignore=tests/e2e/ tests/
+ - name: cleanup pip cache
+ run: |
+ find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
+
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
@@ -87,7 +91,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
- pytorch: 2.3.0
+ pytorch: 2.3.1
num_gpus: 1
steps:
- name: Checkout
@@ -99,7 +103,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
- pip install modal jinja2
+ pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
diff --git a/README.md b/README.md
index 6b35702b3..fd293bd04 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,7 @@ Features:
- [Multipack](./docs/multipack.qmd)
- [RLHF & DPO](./docs/rlhf.qmd)
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
+ - [Unsloth](./docs/unsloth.qmd)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
@@ -333,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
-See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
+See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config
diff --git a/_quarto.yml b/_quarto.yml
index 009fa8056..6b2eed971 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -36,6 +36,7 @@ website:
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-node.qmd
+ - docs/unsloth.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 96c312ddc..263f4a661 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
- pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
- pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
-RUN pip install pytest
+RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
diff --git a/cicd/cicd.sh b/cicd/cicd.sh
index bc36458ab..180150ea2 100755
--- a/cicd/cicd.sh
+++ b/cicd/cicd.sh
@@ -2,5 +2,5 @@
set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest /workspace/axolotl/tests/e2e/patched/
+pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
diff --git a/docker/Dockerfile b/docker/Dockerfile
index cdb6d177a..be58d0354 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
- pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
- pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
diff --git a/docs/torchao.qmd b/docs/torchao.qmd
new file mode 100644
index 000000000..2dc9117fb
--- /dev/null
+++ b/docs/torchao.qmd
@@ -0,0 +1,31 @@
+---
+title: "PyTorch ao"
+description: "Custom data types and layouts for training and inference"
+---
+
+### Installation
+
+Stable release from the PyTorch index:
+
+```bash
+pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
+```
+
+
+Nightly release:
+
+```bash
+pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
+```
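+
+### Usage with Axolotl
+
+A minimal sketch, assuming an Axolotl build that includes the `ao_adamw_*` optimizer options added in this
+change; the learning rate and scheduler values below are illustrative placeholders.
+
+```yaml
+# one of: ao_adamw_4bit, ao_adamw_8bit, ao_adamw_fp8
+optimizer: ao_adamw_8bit
+learning_rate: 0.0002
+lr_scheduler: cosine
+```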
diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd
new file mode 100644
index 000000000..390609fd3
--- /dev/null
+++ b/docs/unsloth.qmd
@@ -0,0 +1,63 @@
+---
+title: "Unsloth"
+description: "Hyper-optimized QLoRA finetuning for single GPUs"
+---
+
+### Overview
+
+Unsloth provides hand-written, optimized kernels for LLM finetuning that modestly improve training speed and
+reduce VRAM usage over standard industry baselines.
+
+
+### Installation
+
+The following installs unsloth from source and downgrades xformers, as unsloth is incompatible with the most
+up-to-date releases of its dependencies.
+
+```bash
+pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
+pip install --no-deps --force-reinstall xformers==0.0.26.post1
+```
+
+### Using unsloth with Axolotl
+
+Axolotl exposes a few configuration options for trying out unsloth and getting most of the performance gains.
+
+Our unsloth integration is currently limited to the following model architectures:
+ - llama
+
+These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning:
+```yaml
+unsloth_lora_mlp: true
+unsloth_lora_qkv: true
+unsloth_lora_o: true
+```
+
+These options are composable and can be used with multi-GPU finetuning:
+```yaml
+unsloth_cross_entropy_loss: true
+unsloth_rms_norm: true
+unsloth_rope: true
+```
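+
+Putting these together, a hypothetical single-GPU QLoRA config fragment enabling all of the above (the
+adapter settings here are illustrative, not required):
+
+```yaml
+adapter: qlora
+load_in_4bit: true
+unsloth_lora_mlp: true
+unsloth_lora_qkv: true
+unsloth_lora_o: true
+unsloth_cross_entropy_loss: true
+unsloth_rms_norm: true
+unsloth_rope: true
+```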
+
+### Limitations
+
+- Single GPU only; i.e. no multi-GPU support
+- No DeepSpeed or FSDP support (both require multi-GPU)
+- LoRA + QLoRA support only. No full fine-tunes or fp8 support.
+- Limited model architecture support. Llama, Phi, Gemma, Mistral only
+- No MoE support.
diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index 94477eb19..3fcc4d2a9 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -171,7 +171,7 @@
},
"outputs": [],
"source": [
- "# Buy using the ! the comand will be executed as a bash command\n",
+    "# By using the ! the command will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
@@ -188,7 +188,7 @@
"metadata": {},
"outputs": [],
"source": [
- "# Buy using the ! the comand will be executed as a bash command\n",
+    "# By using the ! the command will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index a36fd740e..908ef6e03 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
new file mode 100644
index 000000000..14febb810
--- /dev/null
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -0,0 +1,81 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+rl: dpo
+datasets:
+ - path: fozziethebeat/alpaca_messages_2k_dpo_test
+ type: chat_template.default
+ chat_template: llama3
+ field_messages: conversation
+ field_chosen: chosen
+ field_rejected: rejected
+ message_field_role: role
+ message_field_content: content
+ roles:
+ system:
+ - system
+ user:
+ - user
+ assistant:
+ - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index 754c9ad5c..21d32604c 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B-Instruct
+base_model: NousResearch/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml
index dfc881faf..a20a529f5 100644
--- a/examples/llama-3/lora-8b.yml
+++ b/examples/llama-3/lora-8b.yml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
@@ -15,6 +15,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
+eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml
index 44120d938..079c9cad0 100644
--- a/examples/llama-3/qlora.yml
+++ b/examples/llama-3/qlora.yml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
diff --git a/requirements-tests.txt b/requirements-tests.txt
index e079f8a60..9cda381d0 100644
--- a/requirements-tests.txt
+++ b/requirements-tests.txt
@@ -1 +1,2 @@
pytest
+pytest-xdist
diff --git a/requirements.txt b/requirements.txt
index c8d168734..b2aac0dd0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
-transformers==4.42.3
+transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
@@ -12,11 +12,11 @@ fire
PyYAML>=6.0
requests
datasets==2.19.1
-flash-attn==2.5.8
+flash-attn==2.6.1
sentencepiece
wandb
einops
-xformers==0.0.26.post1
+xformers==0.0.27
optimum==1.16.2
hf_transfer
colorama
diff --git a/setup.py b/setup.py
index c7b4e15de..9e6f34ad8 100644
--- a/setup.py
+++ b/setup.py
@@ -29,9 +29,10 @@ def parse_requirements():
_install_requires.append(line)
try:
+ xformers_version = [req for req in _install_requires if "xformers" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
- _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+ _install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -49,12 +50,14 @@ def parse_requirements():
raise ValueError("Invalid version format")
if (major, minor) >= (2, 3):
- pass
+ if patch == 0:
+ _install_requires.pop(_install_requires.index(xformers_version))
+ _install_requires.append("xformers>=0.0.26.post1")
elif (major, minor) >= (2, 2):
- _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+ _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
- _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+ _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
@@ -77,10 +80,10 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
- "flash-attn==2.5.8",
+ "flash-attn==2.6.1",
],
"fused-dense-lib": [
- "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
+ "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
@@ -101,5 +104,11 @@ setup(
"galore": [
"galore_torch",
],
+ "optimizers": [
+ "galore_torch",
+ "lion-pytorch==0.1.2",
+ "lomo-optim==0.1.1",
+ "torch-optimi==0.2.1",
+ ],
},
)
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 7ec3f524a..5966d5931 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
- "n_gpu": os.environ.get("WORLD_SIZE", 1),
+ "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
)
diff --git a/src/axolotl/core/tokenizer_utils.py b/src/axolotl/core/tokenizer_utils.py
new file mode 100644
index 000000000..53c44a75c
--- /dev/null
+++ b/src/axolotl/core/tokenizer_utils.py
@@ -0,0 +1,150 @@
+"""
+helper functions for fixing the embeddings/tokenizer
+"""
+
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import itertools
+
+import numpy as np
+import torch
+
+
+@torch.inference_mode
+def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
+ """
+ Many of the newer models have reserved tokens that are not trained.
+ """
+ embedding_matrix = model.get_input_embeddings().weight
+ lm_head_matrix = model.get_output_embeddings().weight
+
+ # Get untrained tokens
+ indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
+ where_untrained = torch.where(indicator_untrained)[0]
+ n_untrained = where_untrained.shape[0]
+ n_trained = embedding_matrix.shape[0] - n_untrained
+
+ # Get set and actual tokens
+ where_untrained = where_untrained.tolist()
+ if len(where_untrained) == 0:
+ return False
+
+ # Remove untrained indices where it's longer
+
+ where_untrained_set = frozenset(where_untrained)
+ actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
+ # Remove None items in actual_bad_tokens
+ actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
+
+ # Check if tokenizer and training datasets have bad tokens
+ if_bad_first = False
+ if_bad_second = False
+ # Check tokenizer's chat template for any untrained tokens
+ chat_template = getattr(tokenizer, "chat_template", None)
+ if chat_template is not None:
+ if_bad_first = any(x in chat_template for x in actual_bad_tokens)
+
+ # Check the first 250, last 250 input_ids
+ size_dataset = len(train_dataset)
+ size = min(size_dataset, 250)
+ for j in range(size):
+ input_ids = train_dataset[j]
+ if "input_ids" in input_ids:
+ input_ids = input_ids["input_ids"]
+ if_bad = any(item in where_untrained_set for item in input_ids)
+ if if_bad:
+ if_bad_second = True
+ break
+
+ # Check last 250
+ if not if_bad_second:
+ left = max(size_dataset - 250, 0)
+ for j in range(left, size_dataset):
+ input_ids = train_dataset[j]
+ if "input_ids" in input_ids:
+ input_ids = input_ids["input_ids"]
+ if_bad = any(item in where_untrained_set for item in input_ids)
+ if if_bad:
+ if_bad_second = True
+ break
+
+    # Check if bad tokens exist!
+ if not if_bad_first and not if_bad_second:
+ return False
+
+ # Count all the possible bad tokens
+ final_counts = np.zeros(
+ max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
+ )
+
+ def mapping(examples):
+ input_ids = examples["input_ids"]
+ counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
+ np.add.at(final_counts, counter, 1)
+
+ train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
+
+ # Get sum of all items
+ sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
+ sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
+
+ # Remove bad tokens
+ sum_embedding -= torch.sum(
+ embedding_matrix[where_untrained], dtype=torch.float32, axis=0
+ )
+ sum_lm_head -= torch.sum(
+ lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
+ )
+
+ # Find correct average by dividing by sum of trained tokens
+ mean_embedding = sum_embedding / n_trained
+ mean_lm_head = sum_lm_head / n_trained
+
+ # Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
+ scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
+ scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
+ mean_embedding = (
+ mean_embedding.repeat(
+ (
+ n_untrained,
+ 1,
+ )
+ )
+ * scaling
+ )
+ mean_lm_head = (
+ mean_lm_head.repeat(
+ (
+ n_untrained,
+ 1,
+ )
+ )
+ * scaling
+ )
+ where_null = scaling.ravel() == 0
+ mean_embedding[where_null] = 0
+ mean_lm_head[where_null] = 0
+
+ # Set them to the mean
+ embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
+ lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
+
+ # Clean up
+ for _ in range(3):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return True
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ec175454e..616b1d4eb 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
+ alternate_optimizer: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "workaround to pass an alternate optimizer to the HF trainer"
+ },
+ )
@dataclass
@@ -284,26 +290,91 @@ class AxolotlTrainer(Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+ def _wrap_model(self, model, training=True, dataloader=None):
+ if self.args.torch_compile:
+ torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
+ 256
+ )
+ model = torch.compile(
+ model,
+ backend=self.args.torch_compile_backend,
+ mode=self.args.torch_compile_mode,
+ )
+ return super()._wrap_model(model, training=training, dataloader=dataloader)
+
def create_optimizer(self):
- if self.args.loraplus_lr_ratio is None:
+ if (
+ self.args.loraplus_lr_ratio is None
+ and self.args.alternate_optimizer
+ not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
+ ):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
+ decay_parameters = self.get_decay_parameter_names(opt_model)
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.named_parameters()
+ if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
- loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
- loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
- self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
- opt_model,
- optimizer_cls,
- optimizer_kwargs,
- loraplus_lr_ratio,
- loraplus_lr_embedding,
- )
+ if self.args.loraplus_lr_ratio is not None:
+ loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+ loraplus_lr_embedding = getattr(
+ self.args, "loraplus_lr_embedding", None
+ )
+ self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
+ opt_model,
+ optimizer_cls,
+ optimizer_kwargs,
+ loraplus_lr_ratio,
+ loraplus_lr_embedding,
+ )
+ elif self.args.alternate_optimizer == "optimi_adamw":
+ from optimi import AdamW
+
+ self.optimizer = ( # pylint: disable=attribute-defined-outside-init
+ AdamW(
+ optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
+ )
+ )
+ elif self.args.alternate_optimizer == "ao_adamw_4bit":
+ from torchao.prototype.low_bit_optim import AdamW4bit
+
+ self.optimizer = ( # pylint: disable=attribute-defined-outside-init
+ AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
+ )
+ elif self.args.alternate_optimizer == "ao_adamw_8bit":
+ from torchao.prototype.low_bit_optim import AdamW8bit
+
+ self.optimizer = ( # pylint: disable=attribute-defined-outside-init
+ AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
+ )
+ elif self.args.alternate_optimizer == "ao_adamw_fp8":
+ from torchao.prototype.low_bit_optim import AdamWFp8
+
+ self.optimizer = ( # pylint: disable=attribute-defined-outside-init
+ AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
+ )
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1235,6 +1306,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
+ if self.cfg.torch_compile_mode:
+ training_arguments_kwargs[
+ "torch_compile_mode"
+ ] = self.cfg.torch_compile_mode
# DDP Config
if self.cfg.ddp_timeout:
@@ -1396,6 +1471,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
+ if self.cfg.optimizer in [
+ "optimi_adamw",
+ "ao_adamw_4bit",
+ "ao_adamw_8bit",
+ "ao_adamw_fp8",
+ ]:
+ # Set default so transformers doesn't throw
+ training_arguments_kwargs["optim"] = "adamw_hf"
+ training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
+
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
@@ -1424,6 +1509,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
+ if self.cfg.accelerator_config:
+ training_arguments_kwargs[
+ "accelerator_config"
+ ] = self.cfg.accelerator_config
+
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -1621,6 +1711,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
+ training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
@@ -1688,8 +1779,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
- if self.cfg.rl == "dpo":
- dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
diff --git a/src/axolotl/integrations/__init__.py b/src/axolotl/integrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 6d7a23f0d..4c3571ea4 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)
+def patch_llama_cross_entropy():
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+ LOG.info("patching with flash_attn.losses.cross_entropy")
+ transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+ CrossEntropyLoss, inplace_backward=True
+ )
+
+
+def patch_llama_rms_norm():
+ try:
+ from flash_attn.ops.rms_norm import RMSNorm
+
+ class LlamaRMSNorm(RMSNorm):
+ """Patched LLamaRMSNorm"""
+
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__(hidden_size, eps=eps)
+
+ LOG.info("patching with flash_attn.ops.rms_norm")
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+ except ImportError:
+ LOG.warning(
+ "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+ )
+
+
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -104,35 +131,11 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled
if cross_entropy:
- try:
- from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
- LOG.info("patching with flash_attn.losses.cross_entropy")
- transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
- CrossEntropyLoss, inplace_backward=True
- )
- except ImportError:
- LOG.warning(
- "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
- )
+ patch_llama_cross_entropy()
# skip only if explicitly disabled
if rms_norm:
- try:
- from flash_attn.ops.rms_norm import RMSNorm
-
- class LlamaRMSNorm(RMSNorm):
- """Patched LLamaRMSNorm"""
-
- def __init__(self, hidden_size, eps=1e-6):
- super().__init__(hidden_size, eps=eps)
-
- LOG.info("patching with flash_attn.ops.rms_norm")
- transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
- except ImportError:
- LOG.warning(
- "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
- )
+ patch_llama_rms_norm()
class FusedAttention(LlamaAttention):
diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index c5425dd52..1cbc4278b 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -2,6 +2,7 @@
# pylint: disable=duplicate-code
import logging
+from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
)
+def patch_mistral_cross_entropy():
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+ LOG.info("patching with flash_attn.losses.cross_entropy")
+ transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
+ CrossEntropyLoss, inplace_backward=True
+ )
+
+
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index e319596d0..a2ce0e64f 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -10,6 +10,8 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
+ "llama",
+ "mistral",
"mixtral",
"qwen2",
"qwen2_moe",
@@ -24,12 +26,35 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
def patch_for_multipack(model_type, model_name=None):
+ if model_type == "gemmoe":
+ patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+ elif model_type == "deepseek_v2":
+ patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
+ elif hasattr(transformers, "modeling_flash_attention_utils"):
+ transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
+ get_unpad_data
+ )
+ if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+ patch_mixtral_moe_forward_zero3()
+ return
+
+ # retain for legacy
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
+ elif model_type == "llama":
+ if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
+ transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
+ get_unpad_data
+ )
+ elif model_type == "mistral":
+ if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
+            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
+ get_unpad_data
+ )
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -58,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
- elif model_type == "gemmoe":
- patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
- elif model_type == "jamba":
- patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
- elif model_type == "deepseek_v2":
- patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
def patch_remote(model_name, config_name, modeling_name):
diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py
index 6af3046e1..5b1f0061d 100644
--- a/src/axolotl/monkeypatch/unsloth_.py
+++ b/src/axolotl/monkeypatch/unsloth_.py
@@ -1,18 +1,20 @@
"""module for patching with unsloth optimizations"""
import inspect
-import logging
import re
import types
from typing import Tuple
+import torch
+from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
+from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
-LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
+LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
@@ -97,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
-def integrate_cross_entropy_loss_patch():
- forward = get_forward_code()
- LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
- forward, _ = detab_code(forward)
- assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
+ if model_type == "llama":
+ forward = get_forward_code()
+ LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
+ forward, _ = detab_code(forward)
+ assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
- forward = forward.replace(
- "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
- )
- forward = forward.replace(
- "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
- "",
- )
- forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
- forward = forward.replace(
- "def forward(",
- "def fast_cross_entropy_loss_forward(",
- 1,
- )
+ forward = forward.replace(
+ "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
+ )
+ forward = forward.replace(
+ "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
+ "",
+ )
+ forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
+ forward = forward.replace(
+ "def forward(",
+ "def fast_cross_entropy_loss_forward(",
+ 1,
+ )
- # load imports necessary
- import transformers.models.llama.modeling_llama
+ # load imports necessary
+ import transformers.models.llama.modeling_llama
- items_to_import = []
- for item in dir(transformers.models.llama.modeling_llama):
- if item in forward:
- items_to_import.append(item)
+ items_to_import = []
+ for item in dir(transformers.models.llama.modeling_llama):
+ if item in forward:
+ items_to_import.append(item)
- exec( # pylint: disable=exec-used # nosec B102
- "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
- globals(),
- )
+ exec( # pylint: disable=exec-used # nosec B102
+ "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
+ globals(),
+ )
- exec( # pylint: disable=exec-used # nosec B102
- "from transformers.models.llama.modeling_llama import ("
- + ", ".join(x for x in items_to_import)
- + ")",
- globals(),
- )
- exec(forward, globals()) # pylint: disable=exec-used # nosec B102
- print("patching unsloth fast_cross_entropy_loss")
- LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
+ exec( # pylint: disable=exec-used # nosec B102
+ "from transformers.models.llama.modeling_llama import ("
+ + ", ".join(x for x in items_to_import)
+ + ")",
+ globals(),
+ )
+ exec(forward, globals()) # pylint: disable=exec-used # nosec B102
+ LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
+ LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
+ else:
+ raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
@@ -179,12 +184,30 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
- print("patching unsloth attn lora")
+ LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
+def integrate_rope_embeddings():
+ import transformers.models.llama.modeling_llama
+ from unsloth.kernels.rope_embedding import fast_rope_embedding
+
+ def apply_rotary_pos_emb( # pylint: disable=unused-argument
+ q, # pylint: disable=invalid-name
+ k, # pylint: disable=invalid-name
+ cos,
+ sin,
+ position_ids=None,
+ unsqueeze_dim=1,
+ ):
+ return fast_rope_embedding(q, k, cos, sin)
+
+ LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
+ transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
@@ -217,7 +240,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
- logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
+ LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -243,9 +266,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
- logging.warning(
- "unable to apply unsloth lora qkv patch to layer %d", idx
- )
+ LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -264,6 +285,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
- logging.warning(
+ LOG.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
+
+
+def patch_unsloth_layernorm():
+ try:
+ import transformers.models.llama.modeling_llama
+ from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
+
+ class LlamaRMSNorm(nn.Module):
+ """LlamaRMSNorm"""
+
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ return Fast_RMS_Layernorm.apply(
+ hidden_states, self.weight, self.variance_epsilon, False
+ )
+
+ LOG.info("patching with unsloth.kernels.rms_layernorm")
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+ except ImportError:
+ LOG.warning("missing unsloth library")
diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
new file mode 100644
index 000000000..4f2f14098
--- /dev/null
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -0,0 +1,78 @@
+"""
+DPO prompt strategies for using tokenizer chat templates.
+"""
+
+from axolotl.utils.chat_templates import chat_templates
+
+
+def default(
+ cfg, dataset_idx=0, **kwargs
+): # pylint: disable=possibly-unused-variable,unused-argument
+ ds_cfg = cfg["datasets"][dataset_idx]
+ chat_template_str = chat_templates(cfg.chat_template)
+
+ field_messages = ds_cfg.get("field_messages", "messages")
+ field_chosen = ds_cfg.get("field_chosen", "chosen")
+ field_rejected = ds_cfg.get("field_rejected", "rejected")
+ field_message_role = ds_cfg.get("message_field_role", "role")
+ field_message_content = ds_cfg.get("message_field_content", "content")
+ role_map_inv = ds_cfg.get(
+ "roles",
+ {
+ "user": ["user"],
+ "assistant": ["assistant"],
+ "system": ["system"],
+ },
+ )
+ role_map = {}
+ for target, sources in role_map_inv.items():
+ for source in sources:
+ role_map[source] = target
+
+ def transform_fn(sample, tokenizer=None):
+ messages = sample[field_messages]
+ messages = [
+ {
+ "role": role_map[m[field_message_role]],
+ "content": m[field_message_content],
+ }
+ for m in messages
+ ]
+ chosen = {
+ "role": role_map[sample[field_chosen][field_message_role]],
+ "content": sample[field_chosen][field_message_content],
+ }
+ rejected = {
+ "role": role_map[sample[field_rejected][field_message_role]],
+ "content": sample[field_rejected][field_message_content],
+ }
+
+ result = {}
+ result["prompt"] = tokenizer.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ chat_template=chat_template_str,
+ tokenize=False,
+ )
+
+ result["chosen"] = tokenizer.apply_chat_template(
+ [chosen],
+ add_generation_prompt=False,
+ chat_template=chat_template_str,
+ tokenize=False,
+ )
+ chosen_strip_index = result["chosen"].find(chosen["content"])
+ result["chosen"] = result["chosen"][chosen_strip_index:]
+
+ result["rejected"] = tokenizer.apply_chat_template(
+ [rejected],
+ add_generation_prompt=False,
+ chat_template=chat_template_str,
+ tokenize=False,
+ )
+ rejected_strip_index = result["rejected"].find(rejected["content"])
+ result["rejected"] = result["rejected"][rejected_strip_index:]
+
+ return result
+
+ return transform_fn
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 99a9b0ba9..5ba5aed56 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -19,6 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
+from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -52,6 +53,15 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+ # enable expandable segments for cuda allocation to improve VRAM usage
+ torch_version = torch.__version__.split(".")
+ torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
+ if torch_major == 2 and torch_minor >= 2:
+ if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
+ os.environ[
+ "PYTORCH_CUDA_ALLOC_CONF"
+ ] = "expandable_segments:True,roundup_power2_divisions:16"
+
# load the tokenizer first
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -114,6 +124,13 @@ def train(
total_num_steps,
)
+ if cfg.fix_untrained_tokens:
+ fix_untrained_tokens(model, tokenizer, train_dataset)
+ if cfg.local_rank == 0:
+ model.save_pretrained(
+ str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+ )
+
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 8845abe1b..945f8e018 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -9,6 +9,7 @@ import os
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
+from importlib.metadata import version
from pydantic import (
BaseModel,
Field,
@@ -84,6 +85,7 @@ class PretrainingDataset(BaseModel):
split: Optional[str] = "train"
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
+ trust_remote_code: Optional[bool] = False
class UserDefinedPrompterType(BaseModel):
@@ -125,6 +127,8 @@ class SFTDataset(BaseModel):
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
+ trust_remote_code: Optional[bool] = False
+
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -165,6 +169,7 @@ class KTODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
+ trust_remote_code: Optional[bool] = False
class RLType(str, Enum):
@@ -350,7 +355,16 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
- Union[OptimizerNames, Literal["lion_pytorch"]]
+ Union[
+ OptimizerNames,
+ Literal[
+ "lion_pytorch",
+ "optimi_adamw",
+ "ao_adamw_4bit",
+ "ao_adamw_8bit",
+ "ao_adamw_fp8",
+ ],
+ ]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -513,6 +527,8 @@ class AxolotlInputConfig(
dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None
+ accelerator_config: Optional[Dict[str, Any]] = None
+
remove_unused_columns: Optional[bool] = None
push_dataset_to_hub: Optional[str] = None
@@ -599,6 +615,8 @@ class AxolotlInputConfig(
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None
unsloth_lora_o: Optional[bool] = None
+ unsloth_rms_norm: Optional[bool] = None
+ unsloth_rope: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
@@ -611,6 +629,9 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
+ torch_compile_mode: Optional[
+ Literal["default", "reduce-overhead", "max-autotune"]
+ ] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
@@ -651,6 +672,8 @@ class AxolotlInputConfig(
] = None
default_system_message: Optional[str] = None
+ fix_untrained_tokens: Optional[bool] = None
+
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
@@ -716,6 +739,24 @@ class AxolotlInputConfig(
)
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_pretraining_split_batches_accelerate(cls, data):
+ # alternatively set ACCELERATE_SPLIT_BATCHES=False
+ if data.get("pretraining_dataset"):
+ accelerator_config = data.get("accelerator_config", {})
+ if not accelerator_config:
+ data["accelerator_config"] = {
+ "split_batches": False,
+ "dispatch_batches": False,
+ }
+ else:
+ if accelerator_config.get("split_batches") is None:
+ data["accelerator_config"]["split_batches"] = False
+ if accelerator_config.get("dispatch_batches") is None:
+ data["accelerator_config"]["dispatch_batches"] = False
+ return data
+
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
@@ -834,7 +875,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
- not self.optimizer or "adamw" not in self.optimizer.value
+ not self.optimizer or "adamw" not in str(self.optimizer).lower()
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@@ -1126,6 +1167,55 @@ class AxolotlInputConfig(
raise ValueError("either datasets or pretraining_dataset is required")
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_xentropy_patch_conflicts(cls, data):
+ if data.get("flash_attn_cross_entropy") and data.get(
+ "unsloth_cross_entropy_loss"
+ ):
+ raise ValueError(
+                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot both be enabled"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_qlora_unsloth(cls, data):
+ if (
+ data.get("unsloth_lora_mlp")
+ or data.get("unsloth_lora_qkv")
+ or data.get("unsloth_lora_o")
+ ):
+ if data.get("adapter") == "lora" or data.get("load_in_8bit"):
+ raise ValueError(
+ "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_unsloth_xformers_version(cls, data):
+ if (
+ data.get("unsloth_lora_mlp")
+ or data.get("unsloth_lora_qkv")
+ or data.get("unsloth_lora_o")
+ ):
+ xformers_version = version("xformers")
+ if xformers_version == "0.0.27":
+ raise ValueError(
+ "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_torch_compile_deepspeed(cls, data):
+ if data.get("deepspeed") and data.get("torch_compile"):
+ raise ValueError(
+ "torch_compile should be set within your deepspeed config file"
+ )
+ return data
+
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1177,3 +1267,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.")
return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_multigpu_unsloth(cls, data):
+ if (
+ data.get("unsloth_lora_mlp")
+ or data.get("unsloth_lora_qkv")
+ or data.get("unsloth_lora_o")
+ ):
+ capabilities = data.get("capabilities")
+ if capabilities and capabilities.get("n_gpu", 0) > 1:
+ raise ValueError(
+ "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
+ )
+ return data
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 7416ca28b..d0324e1eb 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -1,4 +1,5 @@
"""data handling specific to DPO"""
+
import inspect
import logging
from functools import partial
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f9ae2f0ce..0923bb826 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1,7 +1,7 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
-
+import gc
import logging
import math
import os
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model."
)
- if not cfg.gptq and quant_config_exists:
+ if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
@@ -347,6 +347,31 @@ def load_model(
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
+
+ if cfg.is_llama_derived_model:
+ from axolotl.monkeypatch.llama_attn_hijack_flash import (
+ patch_llama_cross_entropy,
+ patch_llama_rms_norm,
+ )
+
+ if cfg.flash_attn_cross_entropy:
+ patch_llama_cross_entropy()
+ if cfg.flash_attn_rms_norm:
+ patch_llama_rms_norm()
+ elif cfg.unsloth_rms_norm:
+ from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
+
+ patch_unsloth_layernorm()
+ if cfg.unsloth_cross_entropy_loss:
+ from axolotl.monkeypatch.unsloth_ import (
+ integrate_cross_entropy_loss_patch,
+ )
+
+ integrate_cross_entropy_loss_patch(model_type="llama")
+ if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
+ from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
+
+ patch_self_attn_lora()
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
@@ -371,6 +396,12 @@ def load_model(
rms_norm=cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
+ elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
+ replace_llama_attn_with_flash_attn(
+ packed=False,
+ cross_entropy=cfg.flash_attn_cross_entropy,
+ rms_norm=cfg.flash_attn_rms_norm,
+ )
elif cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
@@ -393,7 +424,7 @@ def load_model(
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
- integrate_cross_entropy_loss_patch()
+ integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -401,23 +432,12 @@ def load_model(
patch_self_attn_lora()
# Modify mistral derived models
- if (
- cfg.model_config_type == "mistral"
- and cfg.flash_attention
- and cfg.sample_packing
- ):
+ if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
- replace_mistral_attn_with_flash_attn,
+ patch_mistral_cross_entropy,
)
- LOG.info("patching mistral with flash attention")
- replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
-
- if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
- from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
-
- LOG.info("patching _expand_mask")
- hijack_expand_mask()
+ patch_mistral_cross_entropy()
model_kwargs: Dict[str, Any] = {}
@@ -599,9 +619,12 @@ def load_model(
and not cfg.trust_remote_code
and not cfg.gptq
):
- from transformers import LlamaForCausalLM
+ if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+ skip_move_to_device = True
+ if "device_map" in model_kwargs:
+ del model_kwargs["device_map"]
- model = LlamaForCausalLM.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
@@ -634,7 +657,11 @@ def load_model(
base_model,
**model_kwargs,
)
- elif model_type and not cfg.trust_remote_code:
+ elif (
+ model_type
+ and model_type != "AutoModelForCausalLM"
+ and not cfg.trust_remote_code
+ ):
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
@@ -675,6 +702,7 @@ def load_model(
)
else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+            # disabling either of these two still leads to a VRAM spike before settling back down
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
@@ -849,6 +877,15 @@ def load_model(
integrate_lora_patch(model, cfg)
+ if cfg.unsloth_rope:
+ from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
+
+ integrate_rope_embeddings()
+
+ for _ in range(3):
+ gc.collect()
+ torch.cuda.empty_cache()
+
# TODO resume_from_checkpoint handling
return model, lora_config
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 845296b7a..f353aebec 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
- for token in tokenizer.encode(tokens)
+ for token in tokenizer.encode(tokens, add_special_tokens=False)
]
return colored_tokens
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index a16baaae0..65c2d424e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
- if (
- cfg.is_mistral_derived_model and cfg.flash_attention
- ) or cfg.model_config_type == "mamba":
+ if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
new file mode 100644
index 000000000..0991bdd74
--- /dev/null
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -0,0 +1,87 @@
+"""
+E2E tests for lora llama
+"""
+
+import logging
+import os
+import unittest
+from importlib import reload
+from pathlib import Path
+
+import pytest
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+@pytest.fixture(autouse=True)
+def reload_transformers():
+ import transformers.models.llama.modeling_llama
+
+ yield
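+ # re-import modeling_llama so monkeypatches applied during this test do
+ # not leak into subsequent tests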
+ reload(transformers.models.llama.modeling_llama)
+
+
+class TestFAXentropyLlama(unittest.TestCase):
+ """
+ Test case for Llama models using LoRA with multipack
+ """
+
+ @with_temp_dir
+ def test_lora_packing_fa_cross_entropy(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "tokenizer_type": "LlamaTokenizer",
+ "sequence_len": 1024,
+ "sample_packing": True,
+ "flash_attention": True,
+ "flash_attn_cross_entropy": True,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "lora_r": 32,
+ "lora_alpha": 64,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.2,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 8,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index eecd1b3c1..170c37fd6 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
import unittest
+import transformers
+
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
- model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+ load_model(cfg, tokenizer, inference=cli_args.inference)
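+ # the flash-attn patch swaps transformers' _get_unpad_data for a
+ # torch.jit-scripted version, so its __module__ confirms the hook applied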
assert (
- "axolotl.monkeypatch.mistral_attn_hijack_flash"
- in model.model.layers[0].self_attn.forward.__module__
+ "torch.jit"
+ in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
new file mode 100644
index 000000000..62fb63c47
--- /dev/null
+++ b/tests/e2e/test_llama_pretrain.py
@@ -0,0 +1,67 @@
+"""
+E2E tests for llama pretraining
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestPretrainLlama(unittest.TestCase):
+ """
+ Test case for Llama models with pretraining
+ """
+
+ @with_temp_dir
+ def test_pretrain_w_sample_packing(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "tokenizer_type": "LlamaTokenizer",
+ "flash_attention": True,
+ "sequence_len": 1024,
+ "sample_packing": True,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "pretraining_dataset": [
+ {
+ "path": "allenai/c4",
+ "name": "en",
+ "type": "pretrain",
+ }
+ ],
+ "max_steps": 5,
+ "num_epochs": 1,
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "val_set_size": 0.0,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "save_safetensors": True,
+ "bf16": "auto",
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "model.safetensors").exists()
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index c79652bef..4c6fdaaa9 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
- "lora_r": 32,
- "lora_alpha": 64,
+ "lora_r": 8,
+ "lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
"type": "alpaca",
},
],
- "num_epochs": 2,
+ "num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
new file mode 100644
index 000000000..119dd3d7c
--- /dev/null
+++ b/tests/e2e/test_optimizers.py
@@ -0,0 +1,67 @@
+"""
+E2E tests for custom optimizers using Llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestCustomOptimizers(unittest.TestCase):
+ """
+ Test case for custom optimizers using a LoRA llama model
+ """
+
+ @with_temp_dir
+ def test_optimi_adamw(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "tokenizer_type": "LlamaTokenizer",
+ "sequence_len": 1024,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 8,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "optimi_adamw",
+ "lr_scheduler": "cosine",
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py
new file mode 100644
index 000000000..cca48b1cf
--- /dev/null
+++ b/tests/prompt_strategies/test_dpo_chat_templates.py
@@ -0,0 +1,156 @@
+"""
+Tests for the DPO chat_template prompt strategy
+"""
+
+import unittest
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from axolotl.prompt_strategies.dpo.chat_template import default
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello",
+ },
+ {
+ "role": "assistant",
+ "content": "hello",
+ },
+ {
+ "role": "user",
+ "content": "goodbye",
+ },
+ ],
+ "chosen": {
+ "role": "assistant",
+ "content": "goodbye",
+ },
+ "rejected": {
+ "role": "assistant",
+ "content": "party on",
+ },
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="custom_assistant_dataset")
+def fixture_custom_assistant_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "conversation": [
+ {
+ "speaker": "human",
+ "text": "hello",
+ },
+ {
+ "speaker": "agent",
+ "text": "hello",
+ },
+ {
+ "speaker": "human",
+ "text": "goodbye",
+ },
+ ],
+ "better": {
+ "speaker": "agent",
+ "text": "goodbye",
+ },
+ "worse": {
+ "speaker": "agent",
+ "text": "party on",
+ },
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
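+ # llama-3 chat turns terminate with <|eot_id|> rather than the
+ # tokenizer's default eos token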
+ tokenizer.eos_token = "<|eot_id|>"
+
+ return tokenizer
+
+
+class TestAssistantDPOChatTemplateLlama3:
+ """
+ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
+ """
+
+ def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
+ # pylint: disable=duplicate-code
+ transform_fn = default(
+ DictDefault(
+ {
+ "chat_template": "llama3",
+ "datasets": [
+ {
+ "chat_template": "llama3",
+ }
+ ],
+ }
+ )
+ )
+ result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
+ assert result["prompt"] == (
+ "<|begin_of_text|>"
+ + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert result["chosen"] == "goodbye<|eot_id|>"
+ assert result["rejected"] == "party on<|eot_id|>"
+
+ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
+ # pylint: disable=duplicate-code
+ transform_fn = default(
+ DictDefault(
+ {
+ "chat_template": "llama3",
+ "datasets": [
+ {
+ "chat_template": "llama3",
+ "field_messages": "conversation",
+ "field_chosen": "better",
+ "field_rejected": "worse",
+ "message_field_role": "speaker",
+ "message_field_content": "text",
+ "roles": {
+ "user": ["human"],
+ "assistant": ["agent"],
+ "system": ["sys"],
+ },
+ }
+ ],
+ }
+ )
+ )
+ result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
+ assert result["prompt"] == (
+ "<|begin_of_text|>"
+ + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert result["chosen"] == "goodbye<|eot_id|>"
+ assert result["rejected"] == "party on<|eot_id|>"
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index fb623a43d..5d517585f 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase):
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
- "c4",
+ "allenai/c4",
"en",
streaming=True,
)["train"]
@@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase):
{
"pretraining_dataset": [
{
- "path": "c4",
+ "path": "allenai/c4",
"name": "en",
"type": "pretrain",
}