Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -37,6 +37,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.1
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -19,7 +19,6 @@ jobs:
|
|||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -33,8 +32,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -80,7 +80,6 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -94,8 +93,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -136,7 +136,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
8
.github/workflows/nightlies.yml
vendored
8
.github/workflows/nightlies.yml
vendored
@@ -18,7 +18,6 @@ jobs:
|
|||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -32,8 +31,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -80,7 +80,6 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -94,8 +93,9 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
8
.github/workflows/tests.yml
vendored
8
.github/workflows/tests.yml
vendored
@@ -57,6 +57,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pytest --ignore=tests/e2e/ tests/
|
pytest --ignore=tests/e2e/ tests/
|
||||||
|
|
||||||
|
- name: cleanup pip cache
|
||||||
|
run: |
|
||||||
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
@@ -87,7 +91,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.0
|
pytorch: 2.3.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -99,7 +103,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal jinja2
|
pip install modal==0.63.64 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ Features:
|
|||||||
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
|
- [Unsloth](./docs/unsloth.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Debugging Axolotl](#debugging-axolotl)
|
- [Debugging Axolotl](#debugging-axolotl)
|
||||||
@@ -333,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
|
|||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
|
|
||||||
See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
||||||
|
|
||||||
### Config
|
### Config
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ website:
|
|||||||
- docs/nccl.qmd
|
- docs/nccl.qmd
|
||||||
- docs/mac.qmd
|
- docs/mac.qmd
|
||||||
- docs/multi-node.qmd
|
- docs/multi-node.qmd
|
||||||
|
- docs/unsloth.qmd
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
- section: "Reference"
|
- section: "Reference"
|
||||||
|
|||||||
@@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install pytest
|
RUN pip install -r requirements-tests.txt
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
pytest /workspace/axolotl/tests/e2e/patched/
|
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
|
||||||
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
|
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
19
docs/torchao.qmd
Normal file
19
docs/torchao.qmd
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
---
|
||||||
|
title: "PyTorch ao"
|
||||||
|
description: "Custom data types and layouts for training and inference"
|
||||||
|
---
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
Stable Release from the PyTorch index
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Nightly release
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
|
||||||
|
```
|
||||||
49
docs/unsloth.qmd
Normal file
49
docs/unsloth.qmd
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
---
|
||||||
|
title: "Unsloth"
|
||||||
|
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
||||||
|
---
|
||||||
|
|
||||||
|
### Overview
|
||||||
|
|
||||||
|
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
||||||
|
standard industry baselines.
|
||||||
|
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
|
||||||
|
to date libraries.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
|
||||||
|
pip install --no-deps --force-reinstall xformers==0.0.26.post1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using unsloth w Axolotl
|
||||||
|
|
||||||
|
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||||
|
|
||||||
|
Our unsloth integration is currently limited to the following model architectures:
|
||||||
|
- llama
|
||||||
|
|
||||||
|
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
||||||
|
```yaml
|
||||||
|
unsloth_lora_mlp: true
|
||||||
|
unsloth_lora_qkv: true
|
||||||
|
unsloth_lora_o: true
|
||||||
|
```
|
||||||
|
|
||||||
|
These options are composable and can be used with multi-gpu finetuning
|
||||||
|
```
|
||||||
|
unsloth_cross_entropy_loss: true
|
||||||
|
unsloth_rms_norm: true
|
||||||
|
unsloth_rope: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Limitations
|
||||||
|
|
||||||
|
- Single GPU only; e.g. no multi-gpu support
|
||||||
|
- No deepspeed or FSDP support (requires multi-gpu)
|
||||||
|
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
||||||
|
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
||||||
|
- No MoE support.
|
||||||
@@ -171,7 +171,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
"# By using the ! the comand will be executed as a bash command\n",
|
||||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -188,7 +188,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
"# By using the ! the comand will be executed as a bash command\n",
|
||||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Meta-Llama-3-8B
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
81
examples/llama-3/instruct-dpo-lora-8b.yml
Normal file
81
examples/llama-3/instruct-dpo-lora-8b.yml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: llama3
|
||||||
|
rl: dpo
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||||
|
type: chat_template.default
|
||||||
|
chat_template: llama3
|
||||||
|
field_messages: conversation
|
||||||
|
field_chosen: chosen
|
||||||
|
field_rejected: rejected
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
system:
|
||||||
|
- system
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
base_model: NousResearch/Meta-Llama-3-8B-Instruct
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Meta-Llama-3-8B
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
@@ -15,6 +15,7 @@ output_dir: ./outputs/lora-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B
|
base_model: NousResearch/Meta-Llama-3-8B
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
pytest
|
pytest
|
||||||
|
pytest-xdist
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.42.3
|
transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.32.0
|
accelerate==0.32.0
|
||||||
@@ -12,11 +12,11 @@ fire
|
|||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets==2.19.1
|
datasets==2.19.1
|
||||||
flash-attn==2.5.8
|
flash-attn==2.6.1
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.26.post1
|
xformers==0.0.27
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
|
|||||||
21
setup.py
21
setup.py
@@ -29,9 +29,10 @@ def parse_requirements():
|
|||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# don't install xformers on MacOS
|
# don't install xformers on MacOS
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
else:
|
else:
|
||||||
# detect the version of torch already installed
|
# detect the version of torch already installed
|
||||||
# and set it so dependencies don't clobber the torch version
|
# and set it so dependencies don't clobber the torch version
|
||||||
@@ -49,12 +50,14 @@ def parse_requirements():
|
|||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 3):
|
if (major, minor) >= (2, 3):
|
||||||
pass
|
if patch == 0:
|
||||||
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
_install_requires.append("xformers>=0.0.26.post1")
|
||||||
elif (major, minor) >= (2, 2):
|
elif (major, minor) >= (2, 2):
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
else:
|
else:
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.23.post1")
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
@@ -77,10 +80,10 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.5.8",
|
"flash-attn==2.6.1",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
||||||
@@ -101,5 +104,11 @@ setup(
|
|||||||
"galore": [
|
"galore": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
],
|
],
|
||||||
|
"optimizers": [
|
||||||
|
"galore_torch",
|
||||||
|
"lion-pytorch==0.1.2",
|
||||||
|
"lomo-optim==0.1.1",
|
||||||
|
"torch-optimi==0.2.1",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
"n_gpu": os.environ.get("WORLD_SIZE", 1),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
150
src/axolotl/core/tokenizer_utils.py
Normal file
150
src/axolotl/core/tokenizer_utils.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""
|
||||||
|
helper functions for fixing the embeddings/tokenizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode
|
||||||
|
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
||||||
|
"""
|
||||||
|
Many of the newer models have reserved tokens that are not trained.
|
||||||
|
"""
|
||||||
|
embedding_matrix = model.get_input_embeddings().weight
|
||||||
|
lm_head_matrix = model.get_output_embeddings().weight
|
||||||
|
|
||||||
|
# Get untrained tokens
|
||||||
|
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
|
||||||
|
where_untrained = torch.where(indicator_untrained)[0]
|
||||||
|
n_untrained = where_untrained.shape[0]
|
||||||
|
n_trained = embedding_matrix.shape[0] - n_untrained
|
||||||
|
|
||||||
|
# Get set and actual tokens
|
||||||
|
where_untrained = where_untrained.tolist()
|
||||||
|
if len(where_untrained) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Remove untrained indices where it's longer
|
||||||
|
|
||||||
|
where_untrained_set = frozenset(where_untrained)
|
||||||
|
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
||||||
|
# Remove None items in actual_bad_tokens
|
||||||
|
actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
|
||||||
|
|
||||||
|
# Check if tokenizer and training datasets have bad tokens
|
||||||
|
if_bad_first = False
|
||||||
|
if_bad_second = False
|
||||||
|
# Check tokenizer's chat template for any untrained tokens
|
||||||
|
chat_template = getattr(tokenizer, "chat_template", None)
|
||||||
|
if chat_template is not None:
|
||||||
|
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
||||||
|
|
||||||
|
# Check the first 250, last 250 input_ids
|
||||||
|
size_dataset = len(train_dataset)
|
||||||
|
size = min(size_dataset, 250)
|
||||||
|
for j in range(size):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
if_bad = any(item in where_untrained_set for item in input_ids)
|
||||||
|
if if_bad:
|
||||||
|
if_bad_second = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check last 250
|
||||||
|
if not if_bad_second:
|
||||||
|
left = max(size_dataset - 250, 0)
|
||||||
|
for j in range(left, size_dataset):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
if_bad = any(item in where_untrained_set for item in input_ids)
|
||||||
|
if if_bad:
|
||||||
|
if_bad_second = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check if bad tokens exists!
|
||||||
|
if not if_bad_first and not if_bad_second:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Count all the possible bad tokens
|
||||||
|
final_counts = np.zeros(
|
||||||
|
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
def mapping(examples):
|
||||||
|
input_ids = examples["input_ids"]
|
||||||
|
counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
|
||||||
|
np.add.at(final_counts, counter, 1)
|
||||||
|
|
||||||
|
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
||||||
|
|
||||||
|
# Get sum of all items
|
||||||
|
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
||||||
|
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
||||||
|
|
||||||
|
# Remove bad tokens
|
||||||
|
sum_embedding -= torch.sum(
|
||||||
|
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
|
||||||
|
)
|
||||||
|
sum_lm_head -= torch.sum(
|
||||||
|
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find correct average by dividing by sum of trained tokens
|
||||||
|
mean_embedding = sum_embedding / n_trained
|
||||||
|
mean_lm_head = sum_lm_head / n_trained
|
||||||
|
|
||||||
|
# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
|
||||||
|
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
|
||||||
|
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
||||||
|
mean_embedding = (
|
||||||
|
mean_embedding.repeat(
|
||||||
|
(
|
||||||
|
n_untrained,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
* scaling
|
||||||
|
)
|
||||||
|
mean_lm_head = (
|
||||||
|
mean_lm_head.repeat(
|
||||||
|
(
|
||||||
|
n_untrained,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
* scaling
|
||||||
|
)
|
||||||
|
where_null = scaling.ravel() == 0
|
||||||
|
mean_embedding[where_null] = 0
|
||||||
|
mean_lm_head[where_null] = 0
|
||||||
|
|
||||||
|
# Set them to the mean
|
||||||
|
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
|
||||||
|
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
for _ in range(3):
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return True
|
||||||
@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
)
|
)
|
||||||
|
alternate_optimizer: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -284,26 +290,91 @@ class AxolotlTrainer(Trainer):
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
|
if self.args.torch_compile:
|
||||||
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
|
256
|
||||||
|
)
|
||||||
|
model = torch.compile(
|
||||||
|
model,
|
||||||
|
backend=self.args.torch_compile_backend,
|
||||||
|
mode=self.args.torch_compile_mode,
|
||||||
|
)
|
||||||
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if self.args.loraplus_lr_ratio is None:
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.alternate_optimizer
|
||||||
|
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
||||||
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in opt_model.named_parameters()
|
||||||
|
if (n in decay_parameters and p.requires_grad)
|
||||||
|
],
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in opt_model.named_parameters()
|
||||||
|
if (n not in decay_parameters and p.requires_grad)
|
||||||
|
],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
self.args,
|
self.args,
|
||||||
opt_model,
|
opt_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
loraplus_lr_embedding = getattr(
|
||||||
opt_model,
|
self.args, "loraplus_lr_embedding", None
|
||||||
optimizer_cls,
|
)
|
||||||
optimizer_kwargs,
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
loraplus_lr_ratio,
|
opt_model,
|
||||||
loraplus_lr_embedding,
|
optimizer_cls,
|
||||||
)
|
optimizer_kwargs,
|
||||||
|
loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
|
from optimi import AdamW
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamW(
|
||||||
|
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -1235,6 +1306,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"torch_compile_backend"
|
"torch_compile_backend"
|
||||||
] = self.cfg.torch_compile_backend
|
] = self.cfg.torch_compile_backend
|
||||||
|
if self.cfg.torch_compile_mode:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"torch_compile_mode"
|
||||||
|
] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if self.cfg.ddp_timeout:
|
if self.cfg.ddp_timeout:
|
||||||
@@ -1396,6 +1471,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
|
if self.cfg.optimizer in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
]:
|
||||||
|
# Set default so transformers doesn't throw
|
||||||
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
||||||
|
|
||||||
if self.cfg.optimizer == "lion_pytorch":
|
if self.cfg.optimizer == "lion_pytorch":
|
||||||
from lion_pytorch import Lion
|
from lion_pytorch import Lion
|
||||||
|
|
||||||
@@ -1424,6 +1509,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
sys.path.append(self.cfg.torchdistx_path)
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
|
if self.cfg.accelerator_config:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"accelerator_config"
|
||||||
|
] = self.cfg.accelerator_config
|
||||||
|
|
||||||
training_args = (
|
training_args = (
|
||||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
@@ -1621,6 +1711,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
@@ -1688,8 +1779,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["max_target_length"] = None
|
dpo_trainer_kwargs["max_target_length"] = None
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||||
if self.cfg.rl == "dpo":
|
|
||||||
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
0
src/axolotl/integrations/__init__.py
Normal file
0
src/axolotl/integrations/__init__.py
Normal file
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
|
|||||||
set_module_name(model, name, qkv)
|
set_module_name(model, name, qkv)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama_cross_entropy():
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama_rms_norm():
|
||||||
|
try:
|
||||||
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
class LlamaRMSNorm(RMSNorm):
|
||||||
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning(
|
||||||
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn(
|
def replace_llama_attn_with_flash_attn(
|
||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
cross_entropy: Optional[bool] = False,
|
||||||
@@ -104,35 +131,11 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
try:
|
patch_llama_cross_entropy()
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
|
||||||
CrossEntropyLoss, inplace_backward=True
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
LOG.warning(
|
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if rms_norm:
|
if rms_norm:
|
||||||
try:
|
patch_llama_rms_norm()
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
|
||||||
"""Patched LLamaRMSNorm"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__(hidden_size, eps=eps)
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
|
||||||
except ImportError:
|
|
||||||
LOG.warning(
|
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FusedAttention(LlamaAttention):
|
class FusedAttention(LlamaAttention):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mistral_cross_entropy():
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def _make_sliding_window_causal_mask(
|
def _make_sliding_window_causal_mask(
|
||||||
bsz: int,
|
bsz: int,
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
|||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||||
|
"llama",
|
||||||
|
"mistral",
|
||||||
"mixtral",
|
"mixtral",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
@@ -24,12 +26,35 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None):
|
def patch_for_multipack(model_type, model_name=None):
|
||||||
|
if model_type == "gemmoe":
|
||||||
|
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||||
|
elif model_type == "deepseek_v2":
|
||||||
|
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
||||||
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
return
|
||||||
|
|
||||||
|
# retain for legacy
|
||||||
if model_type == "mixtral":
|
if model_type == "mixtral":
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
patch_mixtral_moe_forward_zero3()
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
elif model_type == "llama":
|
||||||
|
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
||||||
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "mistral":
|
||||||
|
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
||||||
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
elif model_type == "qwen2":
|
elif model_type == "qwen2":
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -58,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
elif model_type == "gemmoe":
|
|
||||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
|
||||||
elif model_type == "jamba":
|
|
||||||
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
|
|
||||||
elif model_type == "deepseek_v2":
|
|
||||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name, config_name, modeling_name):
|
def patch_remote(model_name, config_name, modeling_name):
|
||||||
|
|||||||
@@ -1,18 +1,20 @@
|
|||||||
"""module for patching with unsloth optimizations"""
|
"""module for patching with unsloth optimizations"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaFlashAttention2,
|
LlamaFlashAttention2,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
ORIGINAL_CEL_CODE = """ if labels is not None:
|
ORIGINAL_CEL_CODE = """ if labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
@@ -97,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
|
|||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch():
|
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||||
forward = get_forward_code()
|
if model_type == "llama":
|
||||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
forward = get_forward_code()
|
||||||
forward, _ = detab_code(forward)
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
forward, _ = detab_code(forward)
|
||||||
|
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||||
)
|
)
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"def forward(",
|
"def forward(",
|
||||||
"def fast_cross_entropy_loss_forward(",
|
"def fast_cross_entropy_loss_forward(",
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load imports necessary
|
# load imports necessary
|
||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
items_to_import = []
|
items_to_import = []
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
if item in forward:
|
if item in forward:
|
||||||
items_to_import.append(item)
|
items_to_import.append(item)
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
+ ", ".join(x for x in items_to_import)
|
+ ", ".join(x for x in items_to_import)
|
||||||
+ ")",
|
+ ")",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
print("patching unsloth fast_cross_entropy_loss")
|
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported model type")
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
@@ -179,12 +184,30 @@ def patch_self_attn_lora():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
print("patching unsloth attn lora")
|
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||||
LlamaFlashAttention2.forward = (
|
LlamaFlashAttention2.forward = (
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_rope_embeddings():
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from unsloth.kernels.rope_embedding import fast_rope_embedding
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb( # pylint: disable=unused-argument
|
||||||
|
q, # pylint: disable=invalid-name
|
||||||
|
k, # pylint: disable=invalid-name
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
position_ids=None,
|
||||||
|
unsqueeze_dim=1,
|
||||||
|
):
|
||||||
|
return fast_rope_embedding(q, k, cos, sin)
|
||||||
|
|
||||||
|
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
|
||||||
|
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||||
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
||||||
from unsloth.kernels import apply_lora_mlp_swiglu
|
from unsloth.kernels import apply_lora_mlp_swiglu
|
||||||
@@ -217,7 +240,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
|||||||
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
||||||
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
||||||
else:
|
else:
|
||||||
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||||
@@ -243,9 +266,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
layer.self_attn.apply_qkv = apply_lora_qkv
|
layer.self_attn.apply_qkv = apply_lora_qkv
|
||||||
else:
|
else:
|
||||||
layer.self_attn.apply_qkv = original_apply_qkv
|
layer.self_attn.apply_qkv = original_apply_qkv
|
||||||
logging.warning(
|
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
|
||||||
"unable to apply unsloth lora qkv patch to layer %d", idx
|
|
||||||
)
|
|
||||||
if cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_o:
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
@@ -264,6 +285,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
layer.self_attn.apply_o = apply_lora_o
|
layer.self_attn.apply_o = apply_lora_o
|
||||||
else:
|
else:
|
||||||
layer.self_attn.apply_o = original_apply_o
|
layer.self_attn.apply_o = original_apply_o
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_unsloth_layernorm():
|
||||||
|
try:
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
|
||||||
|
|
||||||
|
class LlamaRMSNorm(nn.Module):
|
||||||
|
"""LlamaRMSNorm"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return Fast_RMS_Layernorm.apply(
|
||||||
|
hidden_states, self.weight, self.variance_epsilon, False
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with unsloth.kernels.rms_layernorm")
|
||||||
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning("missing unsloth library")
|
||||||
|
|||||||
78
src/axolotl/prompt_strategies/dpo/chat_template.py
Normal file
78
src/axolotl/prompt_strategies/dpo/chat_template.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
DPO prompt strategies for using tokenizer chat templates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
|
|
||||||
|
def default(
|
||||||
|
cfg, dataset_idx=0, **kwargs
|
||||||
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
ds_cfg = cfg["datasets"][dataset_idx]
|
||||||
|
chat_template_str = chat_templates(cfg.chat_template)
|
||||||
|
|
||||||
|
field_messages = ds_cfg.get("field_messages", "messages")
|
||||||
|
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||||
|
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||||
|
field_message_role = ds_cfg.get("message_field_role", "role")
|
||||||
|
field_message_content = ds_cfg.get("message_field_content", "content")
|
||||||
|
role_map_inv = ds_cfg.get(
|
||||||
|
"roles",
|
||||||
|
{
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
"system": ["system"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
role_map = {}
|
||||||
|
for target, sources in role_map_inv.items():
|
||||||
|
for source in sources:
|
||||||
|
role_map[source] = target
|
||||||
|
|
||||||
|
def transform_fn(sample, tokenizer=None):
|
||||||
|
messages = sample[field_messages]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": role_map[m[field_message_role]],
|
||||||
|
"content": m[field_message_content],
|
||||||
|
}
|
||||||
|
for m in messages
|
||||||
|
]
|
||||||
|
chosen = {
|
||||||
|
"role": role_map[sample[field_chosen][field_message_role]],
|
||||||
|
"content": sample[field_chosen][field_message_content],
|
||||||
|
}
|
||||||
|
rejected = {
|
||||||
|
"role": role_map[sample[field_rejected][field_message_role]],
|
||||||
|
"content": sample[field_rejected][field_message_content],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
result["prompt"] = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result["chosen"] = tokenizer.apply_chat_template(
|
||||||
|
[chosen],
|
||||||
|
add_generation_prompt=False,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||||
|
result["chosen"] = result["chosen"][chosen_strip_index:]
|
||||||
|
|
||||||
|
result["rejected"] = tokenizer.apply_chat_template(
|
||||||
|
[rejected],
|
||||||
|
add_generation_prompt=False,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||||
|
result["rejected"] = result["rejected"][rejected_strip_index:]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
@@ -19,6 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.core.tokenizer_utils import fix_untrained_tokens
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
@@ -52,6 +53,15 @@ class TrainDatasetMeta:
|
|||||||
def train(
|
def train(
|
||||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||||
|
# enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
torch_version = torch.__version__.split(".")
|
||||||
|
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
|
if torch_major == 2 and torch_minor >= 2:
|
||||||
|
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
||||||
|
os.environ[
|
||||||
|
"PYTORCH_CUDA_ALLOC_CONF"
|
||||||
|
] = "expandable_segments:True,roundup_power2_divisions:16"
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
@@ -114,6 +124,13 @@ def train(
|
|||||||
total_num_steps,
|
total_num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.fix_untrained_tokens:
|
||||||
|
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||||
|
if cfg.local_rank == 0:
|
||||||
|
model.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import os
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from importlib.metadata import version
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Field,
|
Field,
|
||||||
@@ -84,6 +85,7 @@ class PretrainingDataset(BaseModel):
|
|||||||
split: Optional[str] = "train"
|
split: Optional[str] = "train"
|
||||||
text_column: Optional[str] = "text"
|
text_column: Optional[str] = "text"
|
||||||
type: Optional[str] = "pretrain"
|
type: Optional[str] = "pretrain"
|
||||||
|
trust_remote_code: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedPrompterType(BaseModel):
|
class UserDefinedPrompterType(BaseModel):
|
||||||
@@ -125,6 +127,8 @@ class SFTDataset(BaseModel):
|
|||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
drop_system_message: Optional[bool] = None
|
drop_system_message: Optional[bool] = None
|
||||||
|
|
||||||
|
trust_remote_code: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
"""User defined typing for DPO"""
|
"""User defined typing for DPO"""
|
||||||
@@ -165,6 +169,7 @@ class KTODataset(BaseModel):
|
|||||||
split: Optional[str] = None
|
split: Optional[str] = None
|
||||||
type: Optional[Union[UserDefinedKTOType, str]] = None
|
type: Optional[Union[UserDefinedKTOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
|
trust_remote_code: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class RLType(str, Enum):
|
class RLType(str, Enum):
|
||||||
@@ -350,7 +355,16 @@ class HyperparametersConfig(BaseModel):
|
|||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[OptimizerNames, Literal["lion_pytorch"]]
|
Union[
|
||||||
|
OptimizerNames,
|
||||||
|
Literal[
|
||||||
|
"lion_pytorch",
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
],
|
||||||
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
||||||
@@ -513,6 +527,8 @@ class AxolotlInputConfig(
|
|||||||
dataloader_prefetch_factor: Optional[int] = None
|
dataloader_prefetch_factor: Optional[int] = None
|
||||||
dataloader_drop_last: Optional[bool] = None
|
dataloader_drop_last: Optional[bool] = None
|
||||||
|
|
||||||
|
accelerator_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
remove_unused_columns: Optional[bool] = None
|
remove_unused_columns: Optional[bool] = None
|
||||||
|
|
||||||
push_dataset_to_hub: Optional[str] = None
|
push_dataset_to_hub: Optional[str] = None
|
||||||
@@ -599,6 +615,8 @@ class AxolotlInputConfig(
|
|||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
unsloth_lora_o: Optional[bool] = None
|
unsloth_lora_o: Optional[bool] = None
|
||||||
|
unsloth_rms_norm: Optional[bool] = None
|
||||||
|
unsloth_rope: Optional[bool] = None
|
||||||
|
|
||||||
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
fsdp: Optional[List[str]] = None
|
fsdp: Optional[List[str]] = None
|
||||||
@@ -611,6 +629,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
torch_compile: Optional[bool] = None
|
torch_compile: Optional[bool] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
|
torch_compile_mode: Optional[
|
||||||
|
Literal["default", "reduce-overhead", "max-autotune"]
|
||||||
|
] = None
|
||||||
|
|
||||||
max_steps: Optional[int] = None
|
max_steps: Optional[int] = None
|
||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
@@ -651,6 +672,8 @@ class AxolotlInputConfig(
|
|||||||
] = None
|
] = None
|
||||||
default_system_message: Optional[str] = None
|
default_system_message: Optional[str] = None
|
||||||
|
|
||||||
|
fix_untrained_tokens: Optional[bool] = None
|
||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
is_preprocess: Optional[bool] = None
|
is_preprocess: Optional[bool] = None
|
||||||
|
|
||||||
@@ -716,6 +739,24 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_pretraining_split_batches_accelerate(cls, data):
|
||||||
|
# alternatively set ACCELERATE_SPLIT_BATCHES=False
|
||||||
|
if data.get("pretraining_dataset"):
|
||||||
|
accelerator_config = data.get("accelerator_config", {})
|
||||||
|
if not accelerator_config:
|
||||||
|
data["accelerator_config"] = {
|
||||||
|
"split_batches": False,
|
||||||
|
"dispatch_batches": False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if accelerator_config.get("split_batches") is None:
|
||||||
|
data["accelerator_config"]["split_batches"] = False
|
||||||
|
if accelerator_config.get("dispatch_batches") is None:
|
||||||
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_gptq_w_revision(cls, data):
|
def check_gptq_w_revision(cls, data):
|
||||||
@@ -834,7 +875,7 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_adamw_optimizer_params(self):
|
def check_adamw_optimizer_params(self):
|
||||||
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
|
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
|
||||||
not self.optimizer or "adamw" not in self.optimizer.value
|
not self.optimizer or "adamw" not in str(self.optimizer).lower()
|
||||||
):
|
):
|
||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
return self
|
return self
|
||||||
@@ -1126,6 +1167,55 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError("either datasets or pretraining_dataset is required")
|
raise ValueError("either datasets or pretraining_dataset is required")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_xentropy_patch_conflicts(cls, data):
|
||||||
|
if data.get("flash_attn_cross_entropy") and data.get(
|
||||||
|
"unsloth_cross_entropy_loss"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_qlora_unsloth(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
|
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
|
||||||
|
raise ValueError(
|
||||||
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_unsloth_xformers_version(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
|
xformers_version = version("xformers")
|
||||||
|
if xformers_version == "0.0.27":
|
||||||
|
raise ValueError(
|
||||||
|
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_torch_compile_deepspeed(cls, data):
|
||||||
|
if data.get("deepspeed") and data.get("torch_compile"):
|
||||||
|
raise ValueError(
|
||||||
|
"torch_compile should be set within your deepspeed config file"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
@@ -1177,3 +1267,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("deepspeed") and data.get("fsdp"):
|
if data.get("deepspeed") and data.get("fsdp"):
|
||||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_multigpu_unsloth(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
|
capabilities = data.get("capabilities")
|
||||||
|
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""data handling specific to DPO"""
|
"""data handling specific to DPO"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
|
|||||||
"Please make sure to point to a GPTQ model."
|
"Please make sure to point to a GPTQ model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not cfg.gptq and quant_config_exists:
|
if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
"model_config.quantization_config is set but `gptq` flag is not. "
|
||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
||||||
@@ -347,6 +347,31 @@ def load_model(
|
|||||||
and cfg.sample_packing
|
and cfg.sample_packing
|
||||||
):
|
):
|
||||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model:
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
patch_llama_cross_entropy,
|
||||||
|
patch_llama_rms_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attn_cross_entropy:
|
||||||
|
patch_llama_cross_entropy()
|
||||||
|
if cfg.flash_attn_rms_norm:
|
||||||
|
patch_llama_rms_norm()
|
||||||
|
elif cfg.unsloth_rms_norm:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||||
|
|
||||||
|
patch_unsloth_layernorm()
|
||||||
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import (
|
||||||
|
integrate_cross_entropy_loss_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
|
|
||||||
|
patch_self_attn_lora()
|
||||||
elif cfg.is_llama_derived_model:
|
elif cfg.is_llama_derived_model:
|
||||||
# Modify all llama derived models in one block
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
@@ -371,6 +396,12 @@ def load_model(
|
|||||||
rms_norm=cfg.flash_attn_rms_norm,
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
use_shifted_sparse_attn=True,
|
use_shifted_sparse_attn=True,
|
||||||
)
|
)
|
||||||
|
elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
|
||||||
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=False,
|
||||||
|
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||||
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
|
)
|
||||||
elif cfg.xformers_attention:
|
elif cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
@@ -393,7 +424,7 @@ def load_model(
|
|||||||
if cfg.unsloth_cross_entropy_loss:
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch()
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
|
|
||||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
@@ -401,23 +432,12 @@ def load_model(
|
|||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
# Modify mistral derived models
|
# Modify mistral derived models
|
||||||
if (
|
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
|
||||||
cfg.model_config_type == "mistral"
|
|
||||||
and cfg.flash_attention
|
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
replace_mistral_attn_with_flash_attn,
|
patch_mistral_cross_entropy,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching mistral with flash attention")
|
patch_mistral_cross_entropy()
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
|
||||||
|
|
||||||
LOG.info("patching _expand_mask")
|
|
||||||
hijack_expand_mask()
|
|
||||||
|
|
||||||
model_kwargs: Dict[str, Any] = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
@@ -599,9 +619,12 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
):
|
):
|
||||||
from transformers import LlamaForCausalLM
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
skip_move_to_device = True
|
||||||
|
if "device_map" in model_kwargs:
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -634,7 +657,11 @@ def load_model(
|
|||||||
base_model,
|
base_model,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type and not cfg.trust_remote_code:
|
elif (
|
||||||
|
model_type
|
||||||
|
and model_type != "AutoModelForCausalLM"
|
||||||
|
and not cfg.trust_remote_code
|
||||||
|
):
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -675,6 +702,7 @@ def load_model(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
|
# disabling either of these two still leads to VRAM spike before setting back down
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
@@ -849,6 +877,15 @@ def load_model(
|
|||||||
|
|
||||||
integrate_lora_patch(model, cfg)
|
integrate_lora_patch(model, cfg)
|
||||||
|
|
||||||
|
if cfg.unsloth_rope:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
|
||||||
|
|
||||||
|
integrate_rope_embeddings()
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
|||||||
"""Helper function to process and color tokens."""
|
"""Helper function to process and color tokens."""
|
||||||
colored_tokens = [
|
colored_tokens = [
|
||||||
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
||||||
for token in tokenizer.encode(tokens)
|
for token in tokenizer.encode(tokens, add_special_tokens=False)
|
||||||
]
|
]
|
||||||
return colored_tokens
|
return colored_tokens
|
||||||
|
|
||||||
|
|||||||
@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
|
|
||||||
if (
|
if cfg.model_config_type == "mamba":
|
||||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
|
||||||
) or cfg.model_config_type == "mamba":
|
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
|
|||||||
87
tests/e2e/patched/test_fa_xentropy.py
Normal file
87
tests/e2e/patched/test_fa_xentropy.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from importlib import reload
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reload_transformers():
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
yield
|
||||||
|
reload(transformers.models.llama.modeling_llama)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFAXentropyLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA w multipack
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_lora_packing_fa_cross_entropy(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"flash_attn_cross_entropy": True,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.2,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"axolotl.monkeypatch.mistral_attn_hijack_flash"
|
"torch.jit"
|
||||||
in model.model.layers[0].self_attn.forward.__module__
|
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
|
|||||||
67
tests/e2e/test_llama_pretrain.py
Normal file
67
tests/e2e/test_llama_pretrain.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for llama pretrain
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPretrainLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models w pretraining
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_pretrain_w_sample_packing(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"pretraining_dataset": [
|
||||||
|
{
|
||||||
|
"path": "allenai/c4",
|
||||||
|
"name": "en",
|
||||||
|
"type": "pretrain",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 32,
|
"lora_r": 8,
|
||||||
"lora_alpha": 64,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.1,
|
||||||
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 2,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 8,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
|||||||
67
tests/e2e/test_optimizers.py
Normal file
67
tests/e2e/test_optimizers.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for custom optimizers using Llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomOptimizers(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_optimi_adamw(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "optimi_adamw",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
tests for chat_template prompt strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.dpo.chat_template import default
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="assistant_dataset")
|
||||||
|
def fixture_assistant_dataset():
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "goodbye",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"chosen": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "goodbye",
|
||||||
|
},
|
||||||
|
"rejected": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "party on",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="custom_assistant_dataset")
|
||||||
|
def fixture_custom_assistant_dataset():
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversation": [
|
||||||
|
{
|
||||||
|
"speaker": "human",
|
||||||
|
"text": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"speaker": "agent",
|
||||||
|
"text": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"speaker": "human",
|
||||||
|
"text": "goodbye",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"better": {
|
||||||
|
"speaker": "agent",
|
||||||
|
"text": "goodbye",
|
||||||
|
},
|
||||||
|
"worse": {
|
||||||
|
"speaker": "agent",
|
||||||
|
"text": "party on",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
def fixture_llama3_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
tokenizer.eos_token = "<|eot_id|>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssistantDPOChatTemplateLlama3:
|
||||||
|
"""
|
||||||
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
transform_fn = default(
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"chat_template": "llama3",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
|
||||||
|
assert result["prompt"] == (
|
||||||
|
"<|begin_of_text|>"
|
||||||
|
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
|
||||||
|
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
|
||||||
|
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
|
||||||
|
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
assert result["chosen"] == "goodbye<|eot_id|>"
|
||||||
|
assert result["rejected"] == "party on<|eot_id|>"
|
||||||
|
|
||||||
|
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
transform_fn = default(
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "better",
|
||||||
|
"field_rejected": "worse",
|
||||||
|
"message_field_role": "speaker",
|
||||||
|
"message_field_content": "text",
|
||||||
|
"roles": {
|
||||||
|
"user": ["human"],
|
||||||
|
"assistant": ["agent"],
|
||||||
|
"system": ["sys"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
|
||||||
|
assert result["prompt"] == (
|
||||||
|
"<|begin_of_text|>"
|
||||||
|
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
|
||||||
|
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
|
||||||
|
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
|
||||||
|
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
assert result["chosen"] == "goodbye<|eot_id|>"
|
||||||
|
assert result["rejected"] == "party on<|eot_id|>"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
def test_packing_stream_dataset(self):
|
def test_packing_stream_dataset(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
"c4",
|
"allenai/c4",
|
||||||
"en",
|
"en",
|
||||||
streaming=True,
|
streaming=True,
|
||||||
)["train"]
|
)["train"]
|
||||||
@@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
"pretraining_dataset": [
|
"pretraining_dataset": [
|
||||||
{
|
{
|
||||||
"path": "c4",
|
"path": "allenai/c4",
|
||||||
"name": "en",
|
"name": "en",
|
||||||
"type": "pretrain",
|
"type": "pretrain",
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user