diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index d215ea44c..f8eaff270 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -37,6 +37,11 @@ jobs: python_version: "3.11" pytorch: 2.3.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "121" + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8bced628d..4969de75d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,6 @@ jobs: pytorch: 2.1.2 axolotl_extras: axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" - is_latest: true - cuda: 121 cuda_version: 12.1.0 python_version: "3.10" @@ -33,8 +32,9 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.3.0 + pytorch: 2.3.1 axolotl_extras: + is_latest: true runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -80,7 +80,6 @@ jobs: python_version: "3.10" pytorch: 2.1.2 axolotl_extras: - is_latest: true - cuda: 121 cuda_version: 12.1.0 python_version: "3.10" @@ -94,8 +93,9 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.3.0 + pytorch: 2.3.1 axolotl_extras: + is_latest: true runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -136,7 +136,7 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.3.0 + pytorch: 2.3.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 6dc22b6bf..770954b85 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -18,7 +18,6 @@ jobs: pytorch: 2.1.2 axolotl_extras: axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" - is_latest: true - cuda: 121 cuda_version: 12.1.0 python_version: "3.10" @@ -32,8 +31,9 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.3.0 + pytorch: 2.3.1 axolotl_extras: + is_latest: true runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -80,7 +80,6 @@ jobs: python_version: "3.10" pytorch: 2.1.2 axolotl_extras: - is_latest: true - cuda: 121 cuda_version: 12.1.0 python_version: "3.10" @@ -94,8 +93,9 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.3.0 + pytorch: 2.3.1 axolotl_extras: + is_latest: true runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2e2d0968d..1cee8cbcb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,6 +57,10 @@ jobs: run: | pytest --ignore=tests/e2e/ tests/ + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + docker-e2e-tests: if: github.repository_owner == 'axolotl-ai-cloud' # this job needs to be run on self-hosted GPU runners... 
@@ -87,7 +91,7 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.3.0 + pytorch: 2.3.1 num_gpus: 1 steps: - name: Checkout @@ -99,7 +103,7 @@ jobs: - name: Install Modal run: | python -m pip install --upgrade pip - pip install modal jinja2 + pip install modal==0.63.64 jinja2 - name: Update env vars run: | echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV diff --git a/README.md b/README.md index 6b35702b3..fd293bd04 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ Features: - [Multipack](./docs/multipack.qmd) - [RLHF & DPO](./docs/rlhf.qmd) - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd) + - [Unsloth](./docs/unsloth.qmd) - [Common Errors](#common-errors-) - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) - [Debugging Axolotl](#debugging-axolotl) @@ -333,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field. -See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats. +See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats. ### Config diff --git a/_quarto.yml b/_quarto.yml index 009fa8056..6b2eed971 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -36,6 +36,7 @@ website: - docs/nccl.qmd - docs/mac.qmd - docs/multi-node.qmd + - docs/unsloth.qmd - section: "Dataset Formats" contents: docs/dataset-formats/* - section: "Reference" diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 96c312ddc..263f4a661 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \ # If AXOLOTL_EXTRAS is set, append it in brackets RUN pip install causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ + pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ + pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \ fi # So we can test the Docker image -RUN pip install pytest +RUN pip install -r requirements-tests.txt # fix so that git fetch/pull from remote works RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ diff --git a/cicd/cicd.sh b/cicd/cicd.sh index bc36458ab..180150ea2 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -2,5 +2,5 @@ set -e pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ -pytest /workspace/axolotl/tests/e2e/patched/ +pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/ diff --git a/docker/Dockerfile b/docker/Dockerfile index cdb6d177a..be58d0354 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets RUN pip install causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ 
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ - pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ + pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \ fi # So we can test the Docker image diff --git a/docs/torchao.qmd b/docs/torchao.qmd new file mode 100644 index 000000000..2dc9117fb --- /dev/null +++ b/docs/torchao.qmd @@ -0,0 +1,19 @@ +--- +title: "PyTorch ao" +description: "Custom data types and layouts for training and inference" +--- + +Axolotl's `ao_adamw_4bit`, `ao_adamw_8bit`, and `ao_adamw_fp8` optimizer options rely on this package. + +### Installation + +Stable release from the PyTorch index + +```bash +pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124 +``` + + +Nightly release + +```bash +pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124 +``` diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd new file mode 100644 index 000000000..390609fd3 --- /dev/null +++ b/docs/unsloth.qmd @@ -0,0 +1,49 @@ +--- +title: "Unsloth" +description: "Hyper-optimized QLoRA finetuning for single GPUs" +--- + +### Overview + +Unsloth provides hand-written optimized kernels for LLM finetuning that modestly improve speed and reduce VRAM usage over +standard industry baselines. + + +### Installation + +The following will install unsloth from source and downgrade xformers, as unsloth is incompatible with the most +up-to-date libraries. + +```bash +pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git" +pip install --no-deps --force-reinstall xformers==0.0.26.post1 +``` + +### Using unsloth with Axolotl + +Axolotl exposes a few configuration options to try out unsloth and capture most of the performance gains. + +Our unsloth integration is currently limited to the following model architectures: + - llama + +These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning: +```yaml +unsloth_lora_mlp: true +unsloth_lora_qkv: true +unsloth_lora_o: true +``` + +These options are composable and can be used with multi-GPU finetuning: +```yaml +unsloth_cross_entropy_loss: true +unsloth_rms_norm: true +unsloth_rope: true +``` + +### Limitations + +- Single GPU only; no multi-GPU support +- No deepspeed or FSDP support (both require multi-GPU) +- LoRA + QLoRA support only. No full fine-tunes or fp8 support. +- Limited model architecture support: Llama, Phi, Gemma, and Mistral only +- No MoE support.
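+ +### Example config + +A minimal single-GPU QLoRA sketch combining the options above (the base model and LoRA hyperparameters here are illustrative placeholders, not requirements): + +```yaml +base_model: NousResearch/Meta-Llama-3-8B +adapter: qlora +load_in_4bit: true +lora_r: 32 +lora_alpha: 16 +lora_target_linear: true + +unsloth_lora_mlp: true +unsloth_lora_qkv: true +unsloth_lora_o: true +unsloth_cross_entropy_loss: true +unsloth_rms_norm: true +unsloth_rope: true +``` + +Note that the `unsloth_lora_*` options assume a 4-bit QLoRA setup; the config validator rejects them when combined with `adapter: lora` or `load_in_8bit: true`.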
diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 94477eb19..3fcc4d2a9 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -171,7 +171,7 @@ }, "outputs": [], "source": [ - "# Buy using the ! the comand will be executed as a bash command\n", + "# By using the ! the command will be executed as a bash command\n", "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" ] }, @@ -188,7 +188,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Buy using the ! the comand will be executed as a bash command\n", + "# By using the ! the command will be executed as a bash command\n", "!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\n", " --qlora_model_dir=\"./qlora-out\" --gradio" ] } diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index a36fd740e..908ef6e03 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -1,4 +1,4 @@ -base_model: meta-llama/Meta-Llama-3-8B +base_model: NousResearch/Meta-Llama-3-8B model_type: LlamaForCausalLM tokenizer_type: AutoTokenizer diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml new file mode 100644 index 000000000..14febb810 --- /dev/null +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -0,0 +1,81 @@ +base_model: meta-llama/Meta-Llama-3-8B-Instruct +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +chat_template: llama3 +rl: dpo +datasets: + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + chat_template: llama3 + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_field_role: role + message_field_content: content + roles: + system: + - system + user: + - user + assistant: + - assistant + +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml index 754c9ad5c..21d32604c 100644 --- a/examples/llama-3/instruct-lora-8b.yml +++ b/examples/llama-3/instruct-lora-8b.yml @@ -1,4 +1,4 @@ -base_model: meta-llama/Meta-Llama-3-8B-Instruct +base_model: NousResearch/Meta-Llama-3-8B-Instruct model_type: LlamaForCausalLM tokenizer_type: AutoTokenizer diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml index dfc881faf..a20a529f5 100644 --- a/examples/llama-3/lora-8b.yml +++ b/examples/llama-3/lora-8b.yml @@ -1,4 +1,4 @@ -base_model: meta-llama/Meta-Llama-3-8B +base_model: NousResearch/Meta-Llama-3-8B model_type: LlamaForCausalLM tokenizer_type: AutoTokenizer @@ -15,6 +15,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true +eval_sample_packing: false pad_to_sequence_len: true adapter: lora diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml index 44120d938..079c9cad0 100644 --- a/examples/llama-3/qlora.yml +++ b/examples/llama-3/qlora.yml @@ -1,4 +1,4 @@ -base_model: meta-llama/Meta-Llama-3-8B +base_model: NousResearch/Meta-Llama-3-8B model_type: AutoModelForCausalLM tokenizer_type: AutoTokenizer diff --git a/requirements-tests.txt b/requirements-tests.txt index e079f8a60..9cda381d0 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ 
-1 +1,2 @@ pytest +pytest-xdist diff --git a/requirements.txt b/requirements.txt index c8d168734..b2aac0dd0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.11.1 -transformers==4.42.3 +transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf tokenizers==0.19.1 bitsandbytes==0.43.1 accelerate==0.32.0 @@ -12,11 +12,11 @@ fire PyYAML>=6.0 requests datasets==2.19.1 -flash-attn==2.5.8 +flash-attn==2.6.1 sentencepiece wandb einops -xformers==0.0.26.post1 +xformers==0.0.27 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index c7b4e15de..9e6f34ad8 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,10 @@ def parse_requirements(): _install_requires.append(line) try: + xformers_version = [req for req in _install_requires if "xformers" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) else: # detect the version of torch already installed # and set it so dependencies don't clobber the torch version @@ -49,12 +50,14 @@ def parse_requirements(): raise ValueError("Invalid version format") if (major, minor) >= (2, 3): - pass + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.26.post1") elif (major, minor) >= (2, 2): - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") else: - _install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) + _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.23.post1") except PackageNotFoundError: @@ -77,10 +80,10 @@ setup( dependency_links=dependency_links, extras_require={ "flash-attn": [ - "flash-attn==2.5.8", + "flash-attn==2.6.1", ], "fused-dense-lib": [ - "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib", + "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib", ], "deepspeed": [ "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", @@ -101,5 +104,11 @@ setup( "galore": [ "galore_torch", ], + "optimizers": [ + "galore_torch", + "lion-pytorch==0.1.2", + "lomo-optim==0.1.1", + "torch-optimi==0.2.1", + ], }, ) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 7ec3f524a..5966d5931 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): cfg, capabilities={ "bf16": is_torch_bf16_gpu_available(), - "n_gpu": os.environ.get("WORLD_SIZE", 1), + "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "compute_capability": gpu_version, }, ) diff --git a/src/axolotl/core/tokenizer_utils.py b/src/axolotl/core/tokenizer_utils.py new file mode 100644 index 000000000..53c44a75c --- /dev/null +++ b/src/axolotl/core/tokenizer_utils.py @@ -0,0 +1,150 @@ +""" +helper functions for fixing the embeddings/tokenizer +""" + +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import itertools + +import numpy as np +import torch + + +@torch.inference_mode +def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16): + """ + Many of the newer models have reserved tokens that are not trained. + """ + embedding_matrix = model.get_input_embeddings().weight + lm_head_matrix = model.get_output_embeddings().weight + + # Get untrained tokens + indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps + where_untrained = torch.where(indicator_untrained)[0] + n_untrained = where_untrained.shape[0] + n_trained = embedding_matrix.shape[0] - n_untrained + + # Get set and actual tokens + where_untrained = where_untrained.tolist() + if len(where_untrained) == 0: + return False + + # Remove untrained indices where it's longer + + where_untrained_set = frozenset(where_untrained) + actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained) + # Remove None items in actual_bad_tokens + actual_bad_tokens = [x for x in actual_bad_tokens if x is not None] + + # Check if tokenizer and training datasets have bad tokens + if_bad_first = False + if_bad_second = False + # Check tokenizer's chat template for any untrained tokens + chat_template = getattr(tokenizer, "chat_template", None) + if chat_template is not None: + if_bad_first = any(x in chat_template for x in actual_bad_tokens) + + # Check the first 250, last 250 input_ids + size_dataset = len(train_dataset) + size = min(size_dataset, 250) + for j in range(size): + input_ids = train_dataset[j] + if "input_ids" in input_ids: + input_ids = input_ids["input_ids"] + if_bad = any(item in where_untrained_set for item in input_ids) + if if_bad: + if_bad_second = True + break + + # Check last 250 + if not if_bad_second: + left = max(size_dataset - 250, 0) + for j in range(left, size_dataset): + input_ids = train_dataset[j] + if "input_ids" in input_ids: + input_ids = input_ids["input_ids"] + if_bad = any(item in where_untrained_set for item in input_ids) + if if_bad: + if_bad_second = True + break + + # Check if bad tokens exists! 
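+ # (i.e. neither the tokenizer's chat template nor the sampled dataset rows reference an untrained token, so there is nothing to fix)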
+ if not if_bad_first and not if_bad_second: + return False + + # Count all the possible bad tokens + final_counts = np.zeros( + max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64 + ) + + def mapping(examples): + input_ids = examples["input_ids"] + counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32) + np.add.at(final_counts, counter, 1) + + train_dataset.map(mapping, batched=True, desc="Counting untrained tokens") + + # Get sum of all items + sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0) + sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0) + + # Remove bad tokens + sum_embedding -= torch.sum( + embedding_matrix[where_untrained], dtype=torch.float32, axis=0 + ) + sum_lm_head -= torch.sum( + lm_head_matrix[where_untrained], dtype=torch.float32, axis=0 + ) + + # Find correct average by dividing by sum of trained tokens + mean_embedding = sum_embedding / n_trained + mean_lm_head = sum_lm_head / n_trained + + # Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen + scaling = final_counts[where_untrained] / max(final_counts.max(), 1) + scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1) + mean_embedding = ( + mean_embedding.repeat( + ( + n_untrained, + 1, + ) + ) + * scaling + ) + mean_lm_head = ( + mean_lm_head.repeat( + ( + n_untrained, + 1, + ) + ) + * scaling + ) + where_null = scaling.ravel() == 0 + mean_embedding[where_null] = 0 + mean_lm_head[where_null] = 0 + + # Set them to the mean + embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype) + lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype) + + # Clean up + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + + return True diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ec175454e..616b1d4eb 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -226,6 +226,12 @@ class AxolotlTrainingMixins: default=None, metadata={"help": "whether to use sequential sampling for curriculum learning"}, ) + alternate_optimizer: Optional[str] = field( + default=None, + metadata={ + "help": "workaround to pass an alternate optimizer to the HF trainer" + }, + ) @dataclass @@ -284,26 +290,91 @@ class AxolotlTrainer(Trainer): if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.torch_compile: + torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access + 256 + ) + model = torch.compile( + model, + backend=self.args.torch_compile_backend, + mode=self.args.torch_compile_mode, + ) + return super()._wrap_model(model, training=training, dataloader=dataloader) + def create_optimizer(self): - if self.args.loraplus_lr_ratio is None: + if ( + self.args.loraplus_lr_ratio is None + and self.args.alternate_optimizer + not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"] + ): return super().create_optimizer() opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is None: # pylint: disable=access-member-before-definition + decay_parameters = self.get_decay_parameter_names(opt_model) + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in opt_model.named_parameters() + if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p + for n, p 
in opt_model.named_parameters() + if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( self.args, opt_model, ) - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding, - ) + if self.args.loraplus_lr_ratio is not None: + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + loraplus_lr_embedding = getattr( + self.args, "loraplus_lr_embedding", None + ) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + optimizer_kwargs, + loraplus_lr_ratio, + loraplus_lr_embedding, + ) + elif self.args.alternate_optimizer == "optimi_adamw": + from optimi import AdamW + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW( + optimizer_grouped_parameters, foreach=False, **optimizer_kwargs + ) + ) + elif self.args.alternate_optimizer == "ao_adamw_4bit": + from torchao.prototype.low_bit_optim import AdamW4bit + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "ao_adamw_8bit": + from torchao.prototype.low_bit_optim import AdamW8bit + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "ao_adamw_fp8": + from torchao.prototype.low_bit_optim import AdamWFp8 + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) + ) if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init @@ -1235,6 +1306,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs[ "torch_compile_backend" ] = self.cfg.torch_compile_backend + if self.cfg.torch_compile_mode: + training_arguments_kwargs[ + "torch_compile_mode" + ] = self.cfg.torch_compile_mode # DDP Config if self.cfg.ddp_timeout: @@ -1396,6 +1471,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): trainer_kwargs = {} + if self.cfg.optimizer in [ + "optimi_adamw", + "ao_adamw_4bit", + "ao_adamw_8bit", + "ao_adamw_fp8", + ]: + # Set default so transformers doesn't throw + training_arguments_kwargs["optim"] = "adamw_hf" + training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer + if self.cfg.optimizer == "lion_pytorch": from lion_pytorch import Lion @@ -1424,6 +1509,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): sys.path.append(self.cfg.torchdistx_path) importlib.import_module("torchdistx") + if self.cfg.accelerator_config: + training_arguments_kwargs[ + "accelerator_config" + ] = self.cfg.accelerator_config + training_args = ( AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg **training_arguments_kwargs, @@ -1621,6 +1711,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? 
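+ # (trl's ORPOConfig has no separate alpha field; its `beta` is the ORPO lambda/alpha weighting, so cfg.orpo_alpha is passed through under that name)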
training_args_kwargs["beta"] = self.cfg.orpo_alpha + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_cls = AxolotlDPOConfig if self.cfg.rpo_alpha is not None: training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha @@ -1688,8 +1779,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["generate_during_eval"] = True - if self.cfg.rl == "dpo": - dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] diff --git a/src/axolotl/integrations/__init__.py b/src/axolotl/integrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 6d7a23f0d..4c3571ea4 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model): set_module_name(model, name, qkv) +def patch_llama_cross_entropy(): + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( + CrossEntropyLoss, inplace_backward=True + ) + + +def patch_llama_rms_norm(): + try: + from flash_attn.ops.rms_norm import RMSNorm + + class LlamaRMSNorm(RMSNorm): + """Patched LLamaRMSNorm""" + + def __init__(self, hidden_size, eps=1e-6): + super().__init__(hidden_size, eps=eps) + + LOG.info("patching with flash_attn.ops.rms_norm") + transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm + except ImportError: + LOG.warning( + "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" + ) + + def replace_llama_attn_with_flash_attn( packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, @@ -104,35 +131,11 @@ def replace_llama_attn_with_flash_attn( # skip only if explicitly disabled if cross_entropy: - try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss - - LOG.info("patching with flash_attn.losses.cross_entropy") - transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( - CrossEntropyLoss, inplace_backward=True - ) - except ImportError: - LOG.warning( - "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)" - ) + patch_llama_cross_entropy() # skip only if explicitly disabled if rms_norm: - try: - from flash_attn.ops.rms_norm import RMSNorm - - class LlamaRMSNorm(RMSNorm): - """Patched LLamaRMSNorm""" - - def __init__(self, hidden_size, eps=1e-6): - super().__init__(hidden_size, eps=eps) - - LOG.info("patching with flash_attn.ops.rms_norm") - transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm - except ImportError: - LOG.warning( - "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" - ) + patch_llama_rms_norm() class FusedAttention(LlamaAttention): diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index c5425dd52..1cbc4278b 
100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -2,6 +2,7 @@ # pylint: disable=duplicate-code import logging +from functools import partial from typing import List, Optional, Tuple, Union import torch @@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn( ) +def patch_mistral_cross_entropy(): + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial( + CrossEntropyLoss, inplace_backward=True + ) + + @torch.jit.script def _make_sliding_window_causal_mask( bsz: int, diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index e319596d0..a2ce0e64f 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -10,6 +10,8 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "llama", + "mistral", "mixtral", "qwen2", "qwen2_moe", @@ -24,12 +26,35 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ def patch_for_multipack(model_type, model_name=None): + if model_type == "gemmoe": + patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") + elif model_type == "deepseek_v2": + patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") + elif hasattr(transformers, "modeling_flash_attention_utils"): + transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + if model_type == "mixtral" and is_deepspeed_zero3_enabled(): + patch_mixtral_moe_forward_zero3() + return + + # retain for legacy if model_type == "mixtral": transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) if is_deepspeed_zero3_enabled(): patch_mixtral_moe_forward_zero3() + elif model_type == "llama": + if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"): + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) + elif model_type == "mistral": + if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"): + transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) elif model_type == "qwen2": transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data @@ -58,12 +83,6 @@ transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) - elif model_type == "gemmoe": - patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") - elif model_type == "jamba": - patch_remote(model_name, ".configuration_jamba", ".modeling_jamba") - elif model_type == "deepseek_v2": - patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") def patch_remote(model_name, config_name, modeling_name): diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 6af3046e1..5b1f0061d 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -1,18 +1,20 @@ """module for patching with unsloth optimizations""" import inspect -import logging import re import types from typing import Tuple +import torch +from accelerate.logging
import get_logger from peft import PeftModelForCausalLM +from torch import nn from transformers.models.llama.modeling_llama import ( LlamaFlashAttention2, LlamaForCausalLM, ) -LOG = logging.getLogger("axolotl.monkeypatch.unsloth") +LOG = get_logger("axolotl.monkeypatch.unsloth") ORIGINAL_CEL_CODE = """ if labels is not None: # Shift so that tokens < n predict n @@ -97,48 +99,51 @@ def check_self_attn_is_patchable() -> bool: return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv -def integrate_cross_entropy_loss_patch(): - forward = get_forward_code() - LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access - forward, _ = detab_code(forward) - assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" +def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: + if model_type == "llama": + forward = get_forward_code() + LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access + forward, _ = detab_code(forward) + assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" - forward = forward.replace( - "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" - ) - forward = forward.replace( - "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", - "", - ) - forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) - forward = forward.replace( - "def forward(", - "def fast_cross_entropy_loss_forward(", - 1, - ) + forward = forward.replace( + "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" + ) + forward = forward.replace( + "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", + "", + ) + forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) + forward = forward.replace( + "def forward(", + "def fast_cross_entropy_loss_forward(", + 1, + ) - # load imports necessary - import transformers.models.llama.modeling_llama + # load imports necessary + import transformers.models.llama.modeling_llama - items_to_import = [] - for item in dir(transformers.models.llama.modeling_llama): - if item in forward: - items_to_import.append(item) + items_to_import = [] + for item in dir(transformers.models.llama.modeling_llama): + if item in forward: + items_to_import.append(item) - exec( # pylint: disable=exec-used # nosec B102 - "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", - globals(), - ) + exec( # pylint: disable=exec-used # nosec B102 + "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", + globals(), + ) - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.models.llama.modeling_llama import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - print("patching unsloth fast_cross_entropy_loss") - LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.models.llama.modeling_llama import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(forward, globals()) # pylint: disable=exec-used # nosec B102 + LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True) + LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 + else: + raise ValueError("Unsupported model type") def detab_code(code: str) -> Tuple[str, str]: @@ -179,12 +184,30 
@@ def patch_self_attn_lora(): globals(), ) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 - print("patching unsloth attn lora") + LOG.info("patching unsloth attn lora", main_process_only=True) LlamaFlashAttention2.forward = ( unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 ) +def integrate_rope_embeddings(): + import transformers.models.llama.modeling_llama + from unsloth.kernels.rope_embedding import fast_rope_embedding + + def apply_rotary_pos_emb( # pylint: disable=unused-argument + q, # pylint: disable=invalid-name + k, # pylint: disable=invalid-name + cos, + sin, + position_ids=None, + unsqueeze_dim=1, + ): + return fast_rope_embedding(q, k, cos, sin) + + LOG.info("patching unsloth RoPE embeddings", main_process_only=True) + transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb + + def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): if peft_model.base_model.config.model_type in ["llama", "mistral"]: from unsloth.kernels import apply_lora_mlp_swiglu @@ -217,7 +240,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): if is_mlp_lora and mlp_no_bias and mlp_not_dora: layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) else: - logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx) + LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx) def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): @@ -243,9 +266,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_qkv = apply_lora_qkv else: layer.self_attn.apply_qkv = original_apply_qkv - logging.warning( - "unable to apply unsloth lora qkv patch to layer %d", idx - ) + LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx) if cfg.unsloth_lora_o: layer_modules = [ getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] @@ -264,6 +285,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_o = apply_lora_o else: layer.self_attn.apply_o = original_apply_o - logging.warning( + LOG.warning( "unable to apply unsloth lora o_proj patch to layer %d", idx ) + + +def patch_unsloth_layernorm(): + try: + import transformers.models.llama.modeling_llama + from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm + + class LlamaRMSNorm(nn.Module): + """LlamaRMSNorm""" + + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return Fast_RMS_Layernorm.apply( + hidden_states, self.weight, self.variance_epsilon, False + ) + + LOG.info("patching with unsloth.kernels.rms_layernorm") + transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm + except ImportError: + LOG.warning("missing unsloth library") diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py new file mode 100644 index 000000000..4f2f14098 --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -0,0 +1,78 @@ +""" +DPO prompt strategies for using tokenizer chat templates. 
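+ +A minimal example dataset config consumed by `default` below (values taken from examples/llama-3/instruct-dpo-lora-8b.yml in this change): + + rl: dpo + chat_template: llama3 + datasets: + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_field_role: role + message_field_content: content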
+""" + +from axolotl.utils.chat_templates import chat_templates + + +def default( + cfg, dataset_idx=0, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + ds_cfg = cfg["datasets"][dataset_idx] + chat_template_str = chat_templates(cfg.chat_template) + + field_messages = ds_cfg.get("field_messages", "messages") + field_chosen = ds_cfg.get("field_chosen", "chosen") + field_rejected = ds_cfg.get("field_rejected", "rejected") + field_message_role = ds_cfg.get("message_field_role", "role") + field_message_content = ds_cfg.get("message_field_content", "content") + role_map_inv = ds_cfg.get( + "roles", + { + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ) + role_map = {} + for target, sources in role_map_inv.items(): + for source in sources: + role_map[source] = target + + def transform_fn(sample, tokenizer=None): + messages = sample[field_messages] + messages = [ + { + "role": role_map[m[field_message_role]], + "content": m[field_message_content], + } + for m in messages + ] + chosen = { + "role": role_map[sample[field_chosen][field_message_role]], + "content": sample[field_chosen][field_message_content], + } + rejected = { + "role": role_map[sample[field_rejected][field_message_role]], + "content": sample[field_rejected][field_message_content], + } + + result = {} + result["prompt"] = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=False, + ) + + result["chosen"] = tokenizer.apply_chat_template( + [chosen], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + ) + chosen_strip_index = result["chosen"].find(chosen["content"]) + result["chosen"] = result["chosen"][chosen_strip_index:] + + result["rejected"] = tokenizer.apply_chat_template( + [rejected], + add_generation_prompt=False, + chat_template=chat_template_str, + tokenize=False, + ) + rejected_strip_index = result["rejected"].find(rejected["content"]) + result["rejected"] = result["rejected"][rejected_strip_index:] + + return result + + return transform_fn diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 99a9b0ba9..5ba5aed56 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -19,6 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs +from axolotl.core.tokenizer_utils import fix_untrained_tokens from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except @@ -52,6 +53,15 @@ class TrainDatasetMeta: def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: + # enable expandable segments for cuda allocation to improve VRAM usage + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ[ + "PYTORCH_CUDA_ALLOC_CONF" + ] = "expandable_segments:True,roundup_power2_divisions:16" + # load the tokenizer first LOG.debug( f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}", @@ -114,6 +124,13 @@ def train( total_num_steps, ) + if cfg.fix_untrained_tokens: + fix_untrained_tokens(model, tokenizer, train_dataset) + if cfg.local_rank == 0: + model.save_pretrained( + str(Path(cfg.output_dir)), safe_serialization=safe_serialization + ) + # go ahead and presave, so we have the adapter config available to inspect if peft_config: LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 8845abe1b..945f8e018 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -9,6 +9,7 @@ import os from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union +from importlib.metadata import version from pydantic import ( BaseModel, Field, @@ -84,6 +85,7 @@ class PretrainingDataset(BaseModel): split: Optional[str] = "train" text_column: Optional[str] = "text" type: Optional[str] = "pretrain" + trust_remote_code: Optional[bool] = False class UserDefinedPrompterType(BaseModel): @@ -125,6 +127,8 @@ class SFTDataset(BaseModel): roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None + trust_remote_code: Optional[bool] = False + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" @@ -165,6 +169,7 @@ class KTODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None + trust_remote_code: Optional[bool] = False class RLType(str, Enum): @@ -350,7 +355,16 @@ class HyperparametersConfig(BaseModel): learning_rate: Union[str, float] weight_decay: Optional[float] = 0.0 optimizer: Optional[ - Union[OptimizerNames, Literal["lion_pytorch"]] + Union[ + OptimizerNames, + Literal[ + "lion_pytorch", + "optimi_adamw", + "ao_adamw_4bit", + "ao_adamw_8bit", + "ao_adamw_fp8", + ], + ] ] = OptimizerNames.ADAMW_HF.value optim_args: Optional[Union[str, Dict[str, Any]]] = Field( default=None, metadata={"help": "Optional arguments to supply to optimizer."} @@ -513,6 +527,8 @@ class AxolotlInputConfig( dataloader_prefetch_factor: Optional[int] = None dataloader_drop_last: Optional[bool] = None + accelerator_config: Optional[Dict[str, Any]] = None + remove_unused_columns: Optional[bool] = None push_dataset_to_hub: Optional[str] = None @@ -599,6 +615,8 @@ class AxolotlInputConfig( unsloth_lora_mlp: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None unsloth_lora_o: Optional[bool] = None + unsloth_rms_norm: Optional[bool] = None + unsloth_rope: Optional[bool] = None deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None @@ -611,6 +629,9 @@ class AxolotlInputConfig( torch_compile: Optional[bool] = None torch_compile_backend: Optional[str] = None + torch_compile_mode: Optional[ + Literal["default", "reduce-overhead", "max-autotune"] + ] = None max_steps: Optional[int] = None warmup_steps: Optional[int] = None @@ -651,6 +672,8 @@ class AxolotlInputConfig( ] = None default_system_message: Optional[str] = None + fix_untrained_tokens: Optional[bool] = None + # INTERNALS - document for now, generally not set externally is_preprocess: Optional[bool] = None @@ -716,6 +739,24 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def check_pretraining_split_batches_accelerate(cls, data): + # alternatively set 
ACCELERATE_SPLIT_BATCHES=False + if data.get("pretraining_dataset"): + accelerator_config = data.get("accelerator_config", {}) + if not accelerator_config: + data["accelerator_config"] = { + "split_batches": False, + "dispatch_batches": False, + } + else: + if accelerator_config.get("split_batches") is None: + data["accelerator_config"]["split_batches"] = False + if accelerator_config.get("dispatch_batches") is None: + data["accelerator_config"]["dispatch_batches"] = False + return data + @model_validator(mode="before") @classmethod def check_gptq_w_revision(cls, data): @@ -834,7 +875,7 @@ class AxolotlInputConfig( @model_validator(mode="after") def check_adamw_optimizer_params(self): if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( - not self.optimizer or "adamw" not in self.optimizer.value + not self.optimizer or "adamw" not in str(self.optimizer).lower() ): LOG.warning("adamw hyperparameters found, but no adamw optimizer set") return self @@ -1126,6 +1167,55 @@ class AxolotlInputConfig( raise ValueError("either datasets or pretraining_dataset is required") return data + @model_validator(mode="before") + @classmethod + def check_xentropy_patch_conflicts(cls, data): + if data.get("flash_attn_cross_entropy") and data.get( + "unsloth_cross_entropy_loss" + ): + raise ValueError( + "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_qlora_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + if data.get("adapter") == "lora" or data.get("load_in_8bit"): + raise ValueError( + "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_unsloth_xformers_version(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + xformers_version = version("xformers") + if xformers_version == "0.0.27": + raise ValueError( + "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_torch_compile_deepspeed(cls, data): + if data.get("deepspeed") and data.get("torch_compile"): + raise ValueError( + "torch_compile should be set within your deepspeed config file" + ) + return data + class AxolotlConfigWCapabilities(AxolotlInputConfig): """wrapper to valdiate gpu capabilities with the configured options""" @@ -1177,3 +1267,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data.get("deepspeed") and data.get("fsdp"): raise ValueError("deepspeed and fsdp cannot be used together.") return data + + @model_validator(mode="before") + @classmethod + def check_multigpu_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + capabilities = data.get("capabilities") + if capabilities and capabilities.get("n_gpu", 0) > 1: + raise ValueError( + "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training." 
+ ) + return data diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 7416ca28b..d0324e1eb 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -1,4 +1,5 @@ """data handling specific to DPO""" + import inspect import logging from functools import partial diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index f9ae2f0ce..0923bb826 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,7 +1,7 @@ """Module for models and model loading""" # pylint: disable=too-many-lines - +import gc import logging import math import os @@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef "Please make sure to point to a GPTQ model." ) - if not cfg.gptq and quant_config_exists: + if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit: raise ValueError( "model_config.quantization_config is set but `gptq` flag is not. " "Please use the `gptq` flag to train quantized model or point to a non-quantized model." @@ -347,6 +347,31 @@ def load_model( and cfg.sample_packing ): patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) + + if cfg.is_llama_derived_model: + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_llama_cross_entropy, + patch_llama_rms_norm, + ) + + if cfg.flash_attn_cross_entropy: + patch_llama_cross_entropy() + if cfg.flash_attn_rms_norm: + patch_llama_rms_norm() + elif cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() + if cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import ( + integrate_cross_entropy_loss_patch, + ) + + integrate_cross_entropy_loss_patch(model_type="llama") + if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() elif cfg.is_llama_derived_model: # Modify all llama derived models in one block @@ -371,6 +396,12 @@ def load_model( rms_norm=cfg.flash_attn_rms_norm, use_shifted_sparse_attn=True, ) + elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm: + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=cfg.flash_attn_cross_entropy, + rms_norm=cfg.flash_attn_rms_norm, + ) elif cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, @@ -393,7 +424,7 @@ def load_model( if cfg.unsloth_cross_entropy_loss: from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - integrate_cross_entropy_loss_patch() + integrate_cross_entropy_loss_patch(model_type="llama") if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora @@ -401,23 +432,12 @@ def load_model( patch_self_attn_lora() # Modify mistral derived models - if ( - cfg.model_config_type == "mistral" - and cfg.flash_attention - and cfg.sample_packing - ): + if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss: from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - replace_mistral_attn_with_flash_attn, + patch_mistral_cross_entropy, ) - LOG.info("patching mistral with flash attention") - replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) - - if cfg.is_llama_derived_model and cfg.sample_packing and not inference: - from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask - - LOG.info("patching _expand_mask") - hijack_expand_mask() + patch_mistral_cross_entropy() model_kwargs: Dict[str, Any] 
= {} @@ -599,9 +619,12 @@ def load_model( and not cfg.trust_remote_code and not cfg.gptq ): - from transformers import LlamaForCausalLM + if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + skip_move_to_device = True + if "device_map" in model_kwargs: + del model_kwargs["device_map"] - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( base_model, config=model_config, **model_kwargs, @@ -634,7 +657,11 @@ def load_model( base_model, **model_kwargs, ) - elif model_type and not cfg.trust_remote_code: + elif ( + model_type + and model_type != "AutoModelForCausalLM" + and not cfg.trust_remote_code + ): if cfg.gptq: model = AutoModelForCausalLM.from_pretrained( base_model, @@ -675,6 +702,7 @@ def load_model( ) else: if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True if "device_map" in model_kwargs: del model_kwargs["device_map"] @@ -849,6 +877,15 @@ def load_model( integrate_lora_patch(model, cfg) + if cfg.unsloth_rope: + from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + + integrate_rope_embeddings() + + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + # TODO resume_from_checkpoint handling return model, lora_config diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 845296b7a..f353aebec 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only): """Helper function to process and color tokens.""" colored_tokens = [ color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only) - for token in tokenizer.encode(tokens) + for token in tokenizer.encode(tokens, add_special_tokens=False) ] return colored_tokens diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a16baaae0..65c2d424e 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): max_input_len = np.max(get_dataset_lengths(train_dataset)) LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) - if ( - cfg.is_mistral_derived_model and cfg.flash_attention - ) or cfg.model_config_type == "mamba": + if cfg.model_config_type == "mamba": LOG.info("dropping attention_mask column") train_dataset = train_dataset.remove_columns("attention_mask") if eval_dataset: diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py new file mode 100644 index 000000000..0991bdd74 --- /dev/null +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -0,0 +1,87 @@ +""" +E2E tests for lora llama +""" + +import logging +import os +import unittest +from importlib import reload +from pathlib import Path + +import pytest +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +@pytest.fixture(autouse=True) +def reload_transformers(): + import transformers.models.llama.modeling_llama + + yield + reload(transformers.models.llama.modeling_llama) + + +class TestFAXentropyLlama(unittest.TestCase): 
+ """ + Test case for Llama models using LoRA w multipack + """ + + @with_temp_dir + def test_lora_packing_fa_cross_entropy(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "sample_packing": True, + "flash_attention": True, + "flash_attn_cross_entropy": True, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.2, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index eecd1b3c1..170c37fd6 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu import unittest +import transformers + from axolotl.common.cli import TrainerCliArgs from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase): normalize_config(cfg) cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + load_model(cfg, tokenizer, inference=cli_args.inference) assert ( - "axolotl.monkeypatch.mistral_attn_hijack_flash" - in model.model.layers[0].self_attn.forward.__module__ + "torch.jit" + in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access ) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py new file mode 100644 index 000000000..62fb63c47 --- /dev/null +++ b/tests/e2e/test_llama_pretrain.py @@ -0,0 +1,67 @@ +""" +E2E tests for llama pretrain +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestPretrainLlama(unittest.TestCase): + """ + Test case for Llama models w pretraining + """ + + @with_temp_dir + def test_pretrain_w_sample_packing(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "flash_attention": True, + "sequence_len": 1024, + "sample_packing": True, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "pretraining_dataset": [ + { + "path": "allenai/c4", + "name": "en", + "type": "pretrain", + } + ], + "max_steps": 5, + "num_epochs": 1, + "micro_batch_size": 1, + 
"gradient_accumulation_steps": 1, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index c79652bef..4c6fdaaa9 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase): "sequence_len": 1024, "load_in_8bit": True, "adapter": "lora", - "lora_r": 32, - "lora_alpha": 64, + "lora_r": 8, + "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_linear": True, "val_set_size": 0.1, @@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase): "type": "alpaca", }, ], - "num_epochs": 2, + "num_epochs": 1, "micro_batch_size": 8, "gradient_accumulation_steps": 1, "output_dir": temp_dir, diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py new file mode 100644 index 000000000..119dd3d7c --- /dev/null +++ b/tests/e2e/test_optimizers.py @@ -0,0 +1,67 @@ +""" +E2E tests for custom optimizers using Llama +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestCustomOptimizers(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_optimi_adamw(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "optimi_adamw", + "lr_scheduler": "cosine", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py new file mode 100644 index 000000000..cca48b1cf --- /dev/null +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -0,0 +1,156 @@ +""" +tests for chat_template prompt strategy +""" + +import unittest + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + +from axolotl.prompt_strategies.dpo.chat_template import default +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="assistant_dataset") +def fixture_assistant_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "messages": [ + { + "role": "user", 
+ "content": "hello", + }, + { + "role": "assistant", + "content": "hello", + }, + { + "role": "user", + "content": "goodbye", + }, + ], + "chosen": { + "role": "assistant", + "content": "goodbye", + }, + "rejected": { + "role": "assistant", + "content": "party on", + }, + } + ] + ) + + +@pytest.fixture(name="custom_assistant_dataset") +def fixture_custom_assistant_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversation": [ + { + "speaker": "human", + "text": "hello", + }, + { + "speaker": "agent", + "text": "hello", + }, + { + "speaker": "human", + "text": "goodbye", + }, + ], + "better": { + "speaker": "agent", + "text": "goodbye", + }, + "worse": { + "speaker": "agent", + "text": "party on", + }, + } + ] + ) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") + tokenizer.eos_token = "<|eot_id|>" + + return tokenizer + + +class TestAssistantDPOChatTemplateLlama3: + """ + Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. + """ + + def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "llama3", + "datasets": [ + { + "chat_template": "llama3", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer) + assert result["prompt"] == ( + "<|begin_of_text|>" + + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert result["chosen"] == "goodbye<|eot_id|>" + assert result["rejected"] == "party on<|eot_id|>" + + def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "llama3", + "datasets": [ + { + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "better", + "field_rejected": "worse", + "message_field_role": "speaker", + "message_field_content": "text", + "roles": { + "user": ["human"], + "assistant": ["agent"], + "system": ["sys"], + }, + } + ], + } + ) + ) + result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer) + assert result["prompt"] == ( + "<|begin_of_text|>" + + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>" + + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert result["chosen"] == "goodbye<|eot_id|>" + assert result["rejected"] == "party on<|eot_id|>" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index fb623a43d..5d517585f 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase): def test_packing_stream_dataset(self): # pylint: disable=duplicate-code dataset = load_dataset( - "c4", + "allenai/c4", "en", streaming=True, )["train"] @@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase): { "pretraining_dataset": [ { - "path": "c4", + "path": "allenai/c4", "name": "en", "type": "pretrain", }