Compare commits
52 Commits
mm3
...
upgrade_li
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1bf20f990 | ||
|
|
bb648cbc63 | ||
|
|
8b0bca4842 | ||
|
|
d36baf44b1 | ||
|
|
16c8140d20 | ||
|
|
21c25cf7bc | ||
|
|
32288a5d3c | ||
|
|
052a9a79b4 | ||
|
|
3591bcfaf9 | ||
|
|
dc1de7d81b | ||
|
|
d4dbfa02fe | ||
|
|
5c7e89105d | ||
|
|
74db2a1bae | ||
|
|
e62554c419 | ||
|
|
32c60765ef | ||
|
|
8c3a727f9d | ||
|
|
107b67b852 | ||
|
|
bfc77b0f36 | ||
|
|
e1e0556c99 | ||
|
|
d3c45d27b5 | ||
|
|
2501c1a6a3 | ||
|
|
1d6a5e2bd6 | ||
|
|
718cfb2dd1 | ||
|
|
9bd5f7d015 | ||
|
|
5c629ee444 | ||
|
|
955cca41fc | ||
|
|
e12a2130e9 | ||
|
|
67f744dc8c | ||
|
|
f62e23737b | ||
|
|
54673fd6ca | ||
|
|
6d9a3c4d81 | ||
|
|
335027f155 | ||
|
|
ec4272c3a0 | ||
|
|
68b1369de9 | ||
|
|
cd2d89f467 | ||
|
|
1834cdc364 | ||
|
|
ac128b7b1d | ||
|
|
31591bd94c | ||
|
|
d20b48a61e | ||
|
|
09bf1ceacc | ||
|
|
df359c8a6e | ||
|
|
76883851d2 | ||
|
|
922db77521 | ||
|
|
e73b8dff8d | ||
|
|
2fbc6b0c64 | ||
|
|
8159cbd1ab | ||
|
|
979534c851 | ||
|
|
6d3caadf90 | ||
|
|
dee77232fe | ||
|
|
a560593b1d | ||
|
|
e8d3da0081 | ||
|
|
4ca0a47cfb |
14
.github/workflows/base.yml
vendored
14
.github/workflows/base.yml
vendored
@@ -28,7 +28,19 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
14
.github/workflows/main.yml
vendored
14
.github/workflows/main.yml
vendored
@@ -27,7 +27,12 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -84,7 +89,12 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
13
.github/workflows/multi-gpu-e2e.yml
vendored
13
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,10 +21,17 @@ jobs:
|
||||
pytorch: 2.3.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.1
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
|
||||
14
.github/workflows/nightlies.yml
vendored
14
.github/workflows/nightlies.yml
vendored
@@ -26,7 +26,12 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -83,7 +88,12 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
run: |
|
||||
pip3 install wheel packaging
|
||||
pip3 install -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
|
||||
15
.github/workflows/tests-nightly.yml
vendored
15
.github/workflows/tests-nightly.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.0"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -47,13 +47,14 @@ jobs:
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -81,17 +82,17 @@ jobs:
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
nightly_build: "true"
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.1
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.5.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
|
||||
78
.github/workflows/tests.yml
vendored
78
.github/workflows/tests.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.0"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -49,16 +49,20 @@ jobs:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 show torch
|
||||
pip3 install -U -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -68,33 +72,67 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests:
|
||||
docker-e2e-tests-1st:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 60
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.10"
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.63.64 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.10"
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb
|
||||
known_third_party=wandb,comet_ml
|
||||
|
||||
23
README.md
23
README.md
@@ -14,7 +14,7 @@ Features:
|
||||
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb or mlflow
|
||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||
- And more!
|
||||
|
||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# fastchat conversation (deprecation soon, use chat_template)
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
@@ -515,6 +515,22 @@ wandb_name:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
##### Comet Logging
|
||||
|
||||
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
use_comet:
|
||||
comet_api_key:
|
||||
comet_workspace:
|
||||
comet_project_name:
|
||||
comet_experiment_key:
|
||||
comet_mode:
|
||||
comet_online:
|
||||
comet_experiment_config:
|
||||
```
|
||||
|
||||
##### Special Tokens
|
||||
|
||||
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||
@@ -546,7 +562,8 @@ plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_swiglu: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
|
||||
@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN pip install causal_conv1d
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install -r requirements-tests.txt
|
||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=45 * 60,
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=45 * 60,
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
)
|
||||
|
||||
@@ -14,15 +14,6 @@
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
|
||||
@@ -24,15 +24,6 @@
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
|
||||
@@ -20,15 +20,6 @@
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
|
||||
@@ -7,8 +7,8 @@ load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
datasets:
|
||||
- path: philschmid/guanaco-sharegpt-style
|
||||
type: sharegpt
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
shards: 10
|
||||
val_set_size: 0
|
||||
output_dir: temp_debug/axolotl_outputs/model
|
||||
@@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN pip install causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -83,13 +83,14 @@ lora_on_cpu: true
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
@@ -123,6 +124,48 @@ datasets:
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Using chat template
|
||||
- path: ...
|
||||
# Set type to `chat_template` to use this strategy
|
||||
type: chat_template
|
||||
# Specify the name of the chat template to use
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
chat_template: tokenizer_default
|
||||
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
|
||||
chat_template_jinja:
|
||||
# The key in the data example that contains the messages. Default is "messages".
|
||||
field_messages: messages
|
||||
# The key in the message turn that contains the role. Default is "role".
|
||||
message_field_role: role
|
||||
# The key in the message turn that contains the content. Default is "content".
|
||||
message_field_content: content
|
||||
# Optional[Dict[str, List]]. Roles mapping for the messages.
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
assistant: ["gpt", "assistant", "ai"]
|
||||
system: ["system"]
|
||||
|
||||
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["gpt", "assistant"]
|
||||
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn: train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
message_field_training_detail: train_detail
|
||||
|
||||
|
||||
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
@@ -141,9 +184,16 @@ test_datasets:
|
||||
# use RL training: 'dpo', 'ipo', 'kto'
|
||||
rl:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
# Currently supports chatml and inst (mistral/mixtral)
|
||||
chat_template: chatml
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
|
||||
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Changes the default system message
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
@@ -265,8 +315,21 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
||||
# mlflow configuration if you're using it
|
||||
mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
mlflow_run_name: # Your run name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Comet configuration if you're using it
|
||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
||||
use_comet: # Enable or disable Comet integration.
|
||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
@@ -301,7 +364,7 @@ max_steps:
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
@@ -6,6 +6,8 @@ order: 3
|
||||
|
||||
## sharegpt
|
||||
|
||||
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
|
||||
|
||||
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
@@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||
```
|
||||
|
||||
|
||||
## chat_template
|
||||
|
||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "content": "..."}]}
|
||||
```
|
||||
|
||||
See `config.qmd` for full configs and supported templates.
|
||||
|
||||
### Migrating from sharegpt
|
||||
|
||||
Most configs can be adapted as follows:
|
||||
|
||||
```yaml
|
||||
# old
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# new (if using tokenizer's chat_template)
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
# new (if setting a new chat_template like chatml, gemma, etc)
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
```
|
||||
|
||||
We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{
|
||||
"conversations": [
|
||||
{"from": "system", "value": "You are an AI assistant.", "train": false},
|
||||
{"from": "human", "value": "Hello", "train": false},
|
||||
{"from": "assistant", "value": "Hello", "train": true},
|
||||
{"from": "human", "value": "How are you?", "train": true},
|
||||
{
|
||||
"from": "assistant",
|
||||
"value": "I'm doing very well, thank you!",
|
||||
"train_detail": [
|
||||
{"begin_offset": 0, "end_offset": 8, "train": false},
|
||||
{"begin_offset": 9, "end_offset": 18, "train": true},
|
||||
{"begin_offset": 19, "end_offset": 30, "train": false},
|
||||
],
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "I'm doing very well, thank you!",
|
||||
"train": true,
|
||||
},
|
||||
{"from": "assistant", "value": "Hi there!", "train": true}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The configuration would look like:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: tokenizer_default
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
roles_to_train: []
|
||||
train_on_eos: turn
|
||||
message_field_training: train
|
||||
message_field_training_detail: train_detail
|
||||
```
|
||||
|
||||
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
|
||||
|
||||
@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
|
||||
### Background
|
||||
|
||||
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
|
||||
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
|
||||
type: sharegpt
|
||||
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
>[!Important]
|
||||
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
|
||||
|
||||
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
|
||||
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
|
||||
```jsonc
|
||||
// .vscode/launch.json
|
||||
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Debug axolotl prompt - sharegpt",
|
||||
"name": "Debug axolotl prompt - chat_template",
|
||||
"type": "python",
|
||||
"module": "accelerate.commands.launch",
|
||||
"request": "launch",
|
||||
"args": [
|
||||
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
||||
"-m", "axolotl.cli.train", "dev_chat_template.yml",
|
||||
// The flags below simplify debugging by overriding the axolotl config
|
||||
// with the debugging tips above. Modify as needed.
|
||||
"--dataset_processes=1", // limits data preprocessing to one process
|
||||
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
||||
</div>
|
||||
<br>
|
||||
|
||||
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
|
||||
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
|
||||
|
||||
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
|
||||
|
||||
@@ -16,7 +16,10 @@ chat_template: deepseek_v2
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -11,8 +11,11 @@ chat_template: gemma
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
chat_template: gemma
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
|
||||
63
examples/gemma2/reward-model.yaml
Normal file
63
examples/gemma2/reward-model.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: google/gemma-2-2b
|
||||
model_type: AutoModelForSequenceClassification
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
reward_model: true
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
remove_unused_columns: false
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -4,11 +4,15 @@ tokenizer_type: AutoTokenizer
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
use_tensorboard: true
|
||||
chat_template: jamba
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
chat_template: jamba
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: jamba-large-fsdp-qlora-ft
|
||||
|
||||
@@ -4,28 +4,26 @@ plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_swiglu: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
strict: false
|
||||
|
||||
chat_template: llama3
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_project: check_liger_hf_GA_llama_fix-3
|
||||
wandb_entity: axolotl-ai
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_name: pr/fix333-tr4.46.1
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
|
||||
@@ -11,7 +11,6 @@ rl: dpo
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
chat_template: llama3
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
|
||||
@@ -10,7 +10,6 @@ chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
chat_template: llama3
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
|
||||
77
examples/llama-3/qlora-1b.yml
Normal file
77
examples/llama-3/qlora-1b.yml
Normal file
@@ -0,0 +1,77 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -10,7 +10,6 @@ chat_template: phi_3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
chat_template: phi_3
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
|
||||
@@ -2,3 +2,4 @@ pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
tbparse
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.13.0
|
||||
transformers==4.45.1
|
||||
tokenizers>=0.19.1
|
||||
bitsandbytes==0.44.0
|
||||
accelerate==0.34.2
|
||||
datasets==2.21.0
|
||||
deepspeed==0.14.4
|
||||
peft==0.13.2
|
||||
transformers==4.46.1
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.44.1
|
||||
accelerate==1.0.1
|
||||
datasets==3.0.1
|
||||
deepspeed==0.15.3
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
@@ -16,7 +16,7 @@ flash-attn==2.6.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers==0.0.27
|
||||
xformers>=0.0.23.post1
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
colorama
|
||||
@@ -33,8 +33,8 @@ gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq>=0.2.5
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.3.0
|
||||
triton>=3.1.0
|
||||
liger-kernel==0.3.1
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
@@ -43,6 +43,14 @@ s3fs>=2024.5.0
|
||||
gcsfs>=2024.5.0
|
||||
# adlfs
|
||||
|
||||
trl==0.9.6
|
||||
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
# lm eval harness
|
||||
lm_eval==0.4.4
|
||||
langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.5.0
|
||||
|
||||
315
requirements_env.txt
Normal file
315
requirements_env.txt
Normal file
@@ -0,0 +1,315 @@
|
||||
accelerate==0.34.1
|
||||
addict==2.4.0
|
||||
aiofiles==23.2.1
|
||||
aiohttp==3.9.0
|
||||
aiosignal==1.3.1
|
||||
aiostream==0.5.2
|
||||
alembic==1.13.1
|
||||
annotated-types==0.6.0
|
||||
annoy==1.17.3
|
||||
ansible==6.7.0
|
||||
ansible-core==2.13.13
|
||||
ansible-vault==2.1.0
|
||||
anyio==3.7.1
|
||||
appdirs==1.4.4
|
||||
art==6.0
|
||||
asgiref==3.7.2
|
||||
async-timeout==4.0.2
|
||||
attrdict==2.0.1
|
||||
attrs==22.2.0
|
||||
awscli==1.32.75
|
||||
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
||||
backoff==2.2.1
|
||||
base58==2.1.1
|
||||
beartype==0.17.2
|
||||
bitnet==0.2.1
|
||||
bitsandbytes==0.42.0
|
||||
bittensor==6.7.0
|
||||
black==23.7.0
|
||||
blinker==1.7.0
|
||||
boto3==1.34.75
|
||||
botocore==1.34.75
|
||||
cachetools==5.3.3
|
||||
cachy==0.1.1
|
||||
certifi==2023.7.22
|
||||
cffi==1.16.0
|
||||
cfgv==3.3.1
|
||||
chai-guanaco==1.2.4
|
||||
charset-normalizer==3.2.0
|
||||
cleo==0.6.8
|
||||
click==8.1.7
|
||||
cloudpickle==2.0.0
|
||||
cohere==4.11.2
|
||||
colorama==0.4.4
|
||||
coloredlogs==15.0.1
|
||||
CoLT5-attention==0.10.20
|
||||
contextlib2==21.6.0
|
||||
contourpy==1.2.0
|
||||
cryptography==41.0.3
|
||||
cycler==0.12.1
|
||||
cytoolz==0.12.3
|
||||
databricks-cli==0.18.0
|
||||
dataclasses-json==0.5.7
|
||||
datasets==2.11.0
|
||||
ddt==1.6.0
|
||||
decorator==5.1.1
|
||||
deepspeed==0.15.0
|
||||
# Editable Git install with no remote (dialogpt==0.1)
|
||||
-e /Users/wing/Projects/ml/dialogpt/src
|
||||
dill==0.3.6
|
||||
distlib==0.3.6
|
||||
docker==7.0.0
|
||||
docker-pycreds==0.4.0
|
||||
docstring-parser==0.15
|
||||
docutils==0.16
|
||||
ecdsa==0.18.0
|
||||
einops==0.7.0
|
||||
einops-exts==0.0.4
|
||||
einx==0.1.3
|
||||
entrypoints==0.4
|
||||
eth-hash==0.6.0
|
||||
eth-keys==0.5.0
|
||||
eth-typing==4.0.0
|
||||
eth-utils==2.3.1
|
||||
evaluate==0.4.0
|
||||
exceptiongroup==1.1.1
|
||||
fastapi==0.109.2
|
||||
fastcore==1.5.29
|
||||
ffmpy==0.4.0
|
||||
filelock==3.12.2
|
||||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
||||
fire==0.5.0
|
||||
first==2.0.2
|
||||
flake8==7.0.0
|
||||
Flask==3.0.1
|
||||
fonttools==4.47.2
|
||||
frozendict==2.4.1
|
||||
frozenlist==1.3.3
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||
fsspec==2023.6.0
|
||||
fuzzywuzzy==0.18.0
|
||||
gitdb==4.0.10
|
||||
GitPython==3.1.31
|
||||
google-pasta==0.2.0
|
||||
gradio==4.42.0
|
||||
gradio_client==1.3.0
|
||||
greenlet==2.0.2
|
||||
grpclib==0.4.7
|
||||
gunicorn==21.2.0
|
||||
h11==0.14.0
|
||||
h2==4.1.0
|
||||
hpack==4.0.0
|
||||
httpcore==0.17.3
|
||||
httpx==0.24.1
|
||||
huggingface-hub==0.23.4
|
||||
humanfriendly==10.0
|
||||
hyperframe==6.0.1
|
||||
identify==2.5.24
|
||||
idna==3.4
|
||||
immutables==0.20
|
||||
importlib-metadata==6.7.0
|
||||
importlib-resources==6.1.1
|
||||
inflection==0.5.1
|
||||
iniconfig==2.0.0
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
jmespath==1.0.1
|
||||
joblib==1.3.2
|
||||
jsonlines==3.1.0
|
||||
jsonschema==2.6.0
|
||||
kiwisolver==1.4.5
|
||||
langchain==0.0.144
|
||||
Levenshtein==0.24.0
|
||||
libcst==1.1.0
|
||||
liger-kernel==0.0.0
|
||||
lion-pytorch==0.1.2
|
||||
llama-cpp-python==0.1.36
|
||||
llvmlite==0.40.1
|
||||
local-attention==1.9.0
|
||||
loguru==0.7.0
|
||||
Mako==1.3.2
|
||||
Markdown==3.5.2
|
||||
markdown-it-py==3.0.0
|
||||
markdown2==2.4.10
|
||||
MarkupSafe==2.1.2
|
||||
marshmallow==3.19.0
|
||||
marshmallow-enum==1.5.1
|
||||
matplotlib==3.8.2
|
||||
mccabe==0.7.0
|
||||
mdurl==0.1.2
|
||||
MEGABYTE-pytorch==0.0.7
|
||||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
||||
mlflow==2.10.0
|
||||
modal==0.62.77
|
||||
more-itertools==10.2.0
|
||||
mpmath==1.2.1
|
||||
msgpack==1.0.7
|
||||
msgpack-numpy-opentensor==0.5.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.14
|
||||
munch==2.5.0
|
||||
mypy==1.3.0
|
||||
mypy-extensions==1.0.0
|
||||
nest-asyncio==1.6.0
|
||||
netaddr==0.10.1
|
||||
networkx==3.0rc1
|
||||
nh3==0.2.14
|
||||
nodeenv==1.8.0
|
||||
nomic==2.0.2
|
||||
numba==0.57.1
|
||||
numexpr==2.8.4
|
||||
numpy==1.24.4
|
||||
oauthlib==3.2.2
|
||||
openai==0.27.4
|
||||
openapi==1.1.0
|
||||
openapi-schema-pydantic==1.2.4
|
||||
optimum==1.8.6
|
||||
orjson==3.10.7
|
||||
packaging==23.1
|
||||
pandas==2.0.0
|
||||
parameterized==0.9.0
|
||||
password-strength==0.0.3.post2
|
||||
pastel==0.1.1
|
||||
pathos==0.3.0
|
||||
pathspec==0.11.1
|
||||
pathtools==0.1.2
|
||||
peft==0.11.1
|
||||
pendulum==3.0.0
|
||||
Pillow==9.5.0
|
||||
pip-tools==1.11.0
|
||||
platformdirs==3.2.0
|
||||
pluggy==1.4.0
|
||||
poetry==0.7.1
|
||||
pox==0.3.2
|
||||
ppft==1.7.6.6
|
||||
pre-commit==3.3.2
|
||||
prettytable==3.10.0
|
||||
prompt-toolkit==3.0.39
|
||||
protobuf==3.20.2
|
||||
protobuf3-to-dict==0.1.5
|
||||
psutil==5.9.5
|
||||
psycopg==3.1.18
|
||||
PuLP==2.8.0
|
||||
py==1.11.0
|
||||
py-bip39-bindings==0.1.11
|
||||
py-cpuinfo==9.0.0
|
||||
py-ed25519-zebra-bindings==1.0.1
|
||||
py-sr25519-bindings==0.2.0
|
||||
pyarrow==11.0.0
|
||||
pyasn1==0.6.0
|
||||
pycodestyle==2.11.1
|
||||
pycparser==2.21
|
||||
pycryptodome==3.20.0
|
||||
pydantic==2.5.3
|
||||
pydantic_core==2.14.6
|
||||
pydub==0.25.1
|
||||
pyfiglet==0.8.post1
|
||||
pyflakes==3.2.0
|
||||
Pygments==2.15.1
|
||||
PyJWT==2.8.0
|
||||
pylev==1.4.0
|
||||
PyNaCl==1.5.0
|
||||
pynvml==11.5.0
|
||||
pyparsing==2.4.7
|
||||
pyrsistent==0.14.11
|
||||
pytest==8.0.2
|
||||
pytest-asyncio==0.23.4
|
||||
python-dateutil==2.8.2
|
||||
python-dotenv==1.0.1
|
||||
python-Levenshtein==0.24.0
|
||||
python-multipart==0.0.9
|
||||
pytz==2023.3
|
||||
PyYAML==6.0.1
|
||||
querystring-parser==1.2.4
|
||||
rapidfuzz==3.6.1
|
||||
regex==2023.6.3
|
||||
requests==2.31.0
|
||||
requests-toolbelt==0.8.0
|
||||
resolvelib==0.8.1
|
||||
responses==0.18.0
|
||||
retry==0.9.2
|
||||
rich==13.7.0
|
||||
rsa==4.7.2
|
||||
ruff==0.6.3
|
||||
s3transfer==0.10.1
|
||||
safetensors==0.4.5
|
||||
sagemaker==2.148.0
|
||||
scalecodec==1.2.7
|
||||
schedulefree==1.2.1
|
||||
schema==0.7.5
|
||||
scikit-learn==1.4.0
|
||||
scipy==1.9.3
|
||||
seaborn==0.13.2
|
||||
semantic-version==2.10.0
|
||||
sentencepiece==0.2.0
|
||||
sentry-sdk==1.19.1
|
||||
setproctitle==1.3.2
|
||||
shellingham==1.5.4
|
||||
shortuuid==1.0.11
|
||||
shtab==1.6.5
|
||||
sigtools==4.0.1
|
||||
six==1.16.0
|
||||
skypilot==0.4.1
|
||||
smdebug-rulesconfig==1.0.1
|
||||
smmap==5.0.0
|
||||
sniffio==1.3.0
|
||||
SQLAlchemy==1.4.47
|
||||
sqlparse==0.4.4
|
||||
starlette==0.36.3
|
||||
substrate-interface==1.5.2
|
||||
svgwrite==1.4.3
|
||||
sympy==1.11.1
|
||||
synchronicity==0.6.7
|
||||
tabulate==0.9.0
|
||||
tblib==1.7.0
|
||||
tenacity==8.2.2
|
||||
tensor-parallel==2.0.0
|
||||
termcolor==2.2.0
|
||||
text2art==0.2.0
|
||||
threadpoolctl==3.2.0
|
||||
tiktoken==0.6.0
|
||||
time-machine==2.14.1
|
||||
timm==0.9.16
|
||||
tokenizers==0.19.1
|
||||
tokenmonster==1.1.12
|
||||
toml==0.9.6
|
||||
tomli==2.0.1
|
||||
tomlkit==0.12.0
|
||||
toolz==0.12.1
|
||||
torch==2.2.0
|
||||
torchdata==0.6.1
|
||||
torchdiffeq==0.2.3
|
||||
TorchFix==0.4.0
|
||||
torchtext==0.15.2
|
||||
torchvision==0.17.0
|
||||
tqdm==4.66.2
|
||||
transformers==4.44.2
|
||||
trl==0.9.6
|
||||
typer==0.12.5
|
||||
types-certifi==2021.10.8.3
|
||||
types-requests==2.31.0.20240125
|
||||
types-setuptools==69.0.0.20240125
|
||||
types-toml==0.10.8.7
|
||||
typing==3.7.4.3
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.9.0
|
||||
tyro==0.5.18
|
||||
tzdata==2023.3
|
||||
unique-names-generator==1.0.2
|
||||
urllib3==2.2.2
|
||||
uvicorn==0.22.0
|
||||
vector_quantize_pytorch==1.14.1
|
||||
virtualenv==20.23.0
|
||||
voyager==2.0.2
|
||||
wandb==0.16.2
|
||||
watchfiles==0.21.0
|
||||
wavedrom==2.0.3.post3
|
||||
wcwidth==0.2.6
|
||||
websocket-client==1.7.0
|
||||
websockets==12.0
|
||||
Werkzeug==3.0.1
|
||||
wonderwords==2.2.0
|
||||
xxhash==3.2.0
|
||||
yarl==1.8.2
|
||||
zetascale==2.2.7
|
||||
zipp==3.15.0
|
||||
60
scripts/chat_datasets.py
Normal file
60
scripts/chat_datasets.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
helper script to parse chat datasets into a usable yaml
|
||||
"""
|
||||
import click
|
||||
import yaml
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("dataset", type=str)
|
||||
@click.option("--split", type=str, default="train")
|
||||
def parse_dataset(dataset=None, split="train"):
|
||||
ds_cfg = {}
|
||||
ds_cfg["path"] = dataset
|
||||
ds_cfg["split"] = split
|
||||
ds_cfg["type"] = "chat_template"
|
||||
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
|
||||
|
||||
dataset = load_dataset(dataset, split=split)
|
||||
features = dataset.features
|
||||
feature_keys = features.keys()
|
||||
field_messages = None
|
||||
for key in ["conversation", "conversations", "messages"]:
|
||||
if key in feature_keys:
|
||||
field_messages = key
|
||||
break
|
||||
if not field_messages:
|
||||
raise ValueError(
|
||||
f'No conversation field found in dataset: {", ".join(feature_keys)}'
|
||||
)
|
||||
ds_cfg["field_messages"] = field_messages
|
||||
|
||||
message_fields = features["conversations"][0].keys()
|
||||
message_field_role = None
|
||||
for key in ["from", "role"]:
|
||||
if key in message_fields:
|
||||
message_field_role = key
|
||||
break
|
||||
if not message_field_role:
|
||||
raise ValueError(
|
||||
f'No role field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_role"] = message_field_role
|
||||
|
||||
message_field_content = None
|
||||
for key in ["content", "text", "value"]:
|
||||
if key in message_fields:
|
||||
message_field_content = key
|
||||
break
|
||||
if not message_field_content:
|
||||
raise ValueError(
|
||||
f'No content field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_content"] = message_field_content
|
||||
|
||||
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parse_dataset()
|
||||
23
setup.py
23
setup.py
@@ -30,6 +30,9 @@ def parse_requirements():
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
@@ -49,20 +52,35 @@ def parse_requirements():
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 3):
|
||||
if (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
elif (major, minor) >= (2, 3):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
elif (major, minor) >= (2, 2):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.25.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.23.post1")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
@@ -91,6 +109,7 @@ setup(
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
|
||||
@@ -30,7 +30,8 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
@@ -54,8 +55,22 @@ LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
AXOLOTL_LOGO = """
|
||||
#@@ #@@ @@# @@#
|
||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
"""
|
||||
|
||||
def print_axolotl_text_art(suffix=None):
|
||||
|
||||
def print_legacy_axolotl_text_art(suffix=None):
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
@@ -68,6 +83,13 @@ def print_axolotl_text_art(suffix=None):
|
||||
print_dep_versions()
|
||||
|
||||
|
||||
def print_axolotl_text_art(
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
print(AXOLOTL_LOGO)
|
||||
|
||||
|
||||
def print_dep_versions():
|
||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||
max_len = max(len(pkg) for pkg in packages)
|
||||
@@ -250,7 +272,7 @@ def do_inference_gradio(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
@@ -421,6 +443,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
setup_mlflow_env_vars(cfg)
|
||||
|
||||
setup_comet_env_vars(cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -438,7 +462,12 @@ def load_datasets(
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
if (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
):
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
|
||||
@@ -27,6 +27,7 @@ from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
@@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
with disable_datasets_caching():
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
if parsed_cli_args.download:
|
||||
model_name = parsed_cfg.base_model
|
||||
|
||||
@@ -3,13 +3,11 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -20,6 +18,7 @@ from axolotl.cli import (
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
@@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return do_train(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
def do_train(cfg, cli_args) -> None:
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
@@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -23,7 +23,7 @@ class TrainerCliArgs:
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=5)
|
||||
debug_num_examples: int = field(default=0)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
0
src/axolotl/core/chat/__init__.py
Normal file
0
src/axolotl/core/chat/__init__.py
Normal file
0
src/axolotl/core/chat/format/__init__.py
Normal file
0
src/axolotl/core/chat/format/__init__.py
Normal file
34
src/axolotl/core/chat/format/chatml.py
Normal file
34
src/axolotl/core/chat/format/chatml.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
ChatML transformation functions for MessageContents
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..messages import MessageContents, Messages
|
||||
from .shared import wrap_tools
|
||||
|
||||
|
||||
def format_message(
|
||||
message: Messages,
|
||||
message_index: Optional[int] = None, # pylint: disable=unused-argument
|
||||
) -> Messages:
|
||||
if message.is_chat_formatted:
|
||||
return message
|
||||
|
||||
# prepend the role prefix within a MessageContents to message.content
|
||||
message.content.insert(
|
||||
0,
|
||||
MessageContents(
|
||||
type="text",
|
||||
value=f"<|im_start|>{message.role}\n",
|
||||
weight=0,
|
||||
),
|
||||
)
|
||||
message.content.append(
|
||||
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
|
||||
)
|
||||
message.content.append(MessageContents(type="text", value="\n", weight=0))
|
||||
|
||||
message = wrap_tools(message)
|
||||
|
||||
message.is_chat_formatted = True
|
||||
return message
|
||||
45
src/axolotl/core/chat/format/llama3x.py
Normal file
45
src/axolotl/core/chat/format/llama3x.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Llama 3.x chat formatting functions for MessageContents
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..messages import MessageContents, Messages
|
||||
from .shared import wrap_tools
|
||||
|
||||
|
||||
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
|
||||
if message.is_chat_formatted:
|
||||
return message
|
||||
|
||||
message_role = message.role
|
||||
if message.role == "tool":
|
||||
message_role = "ipython"
|
||||
|
||||
# prepend the role prefix within a MessageContents to message.content
|
||||
message.content.insert(
|
||||
0,
|
||||
MessageContents(
|
||||
type="text",
|
||||
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
|
||||
weight=0,
|
||||
),
|
||||
)
|
||||
|
||||
message.content.append(
|
||||
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
|
||||
)
|
||||
|
||||
message = wrap_tools(message)
|
||||
|
||||
if message_index == 0:
|
||||
message.content.insert(
|
||||
0,
|
||||
MessageContents(
|
||||
type="text",
|
||||
value="<|begin_of_text|>",
|
||||
weight=0,
|
||||
),
|
||||
)
|
||||
|
||||
message.is_chat_formatted = True
|
||||
return message
|
||||
47
src/axolotl/core/chat/format/shared.py
Normal file
47
src/axolotl/core/chat/format/shared.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
shared functions for format transforms
|
||||
"""
|
||||
from axolotl.core.chat.messages import MessageContents, Messages
|
||||
|
||||
|
||||
def wrap_tools(message: Messages):
|
||||
# loop over message.content by index to find tool calls, we need to wrap each with tags,
|
||||
# so be wary of indexing issues when changing the list while iterating.
|
||||
# iterate over the range in reverse order to avoid index shifting
|
||||
for i in range(len(message.content) - 1, -1, -1):
|
||||
if message.content[i].type == "tool_call":
|
||||
# append a </tool_call> MessageContents text tag after
|
||||
message.content.insert(
|
||||
i + 1,
|
||||
MessageContents(
|
||||
type="text", value="</tool_call>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
# make sure the actual tool call content ends with a newline
|
||||
message.content[i].has_newline = True
|
||||
# prepend a <tool_call> MessageContents text tag before
|
||||
message.content.insert(
|
||||
i,
|
||||
MessageContents(
|
||||
type="text", value="<tool_call>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
elif message.content[i].type == "tool_response":
|
||||
# append a </tool_call> MessageContents text tag after
|
||||
message.content.insert(
|
||||
i + 1,
|
||||
MessageContents(
|
||||
type="text", value="</tool_response>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
# make sure the actual tool response content ends with a newline
|
||||
message.content[i].has_newline = True
|
||||
# prepend a <tool_call> MessageContents text tag before
|
||||
message.content.insert(
|
||||
i,
|
||||
MessageContents(
|
||||
type="text", value="<tool_response>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
|
||||
return message
|
||||
230
src/axolotl/core/chat/messages.py
Normal file
230
src/axolotl/core/chat/messages.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
internal message representations of chat messages
|
||||
"""
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
class MessageRoles(str, Enum):
|
||||
"""
|
||||
Message roles for the system, user, assistant, and tools
|
||||
"""
|
||||
|
||||
system = "system" # pylint: disable=invalid-name
|
||||
user = "user" # pylint: disable=invalid-name
|
||||
assistant = "assistant" # pylint: disable=invalid-name
|
||||
tool = "tool" # pylint: disable=invalid-name
|
||||
ipython = ( # pylint: disable=invalid-name
|
||||
# for responses from builtin tools
|
||||
"ipython"
|
||||
)
|
||||
|
||||
|
||||
class MessageContentTypes(str, Enum):
|
||||
"""
|
||||
Message content types for text, image, audio, tool calls, and tool responses
|
||||
"""
|
||||
|
||||
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
|
||||
text = "text" # pylint: disable=invalid-name
|
||||
image = "image" # pylint: disable=invalid-name
|
||||
audio = "audio" # pylint: disable=invalid-name
|
||||
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
|
||||
tool_response = "tool_response" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SpecialToken(str, Enum):
|
||||
"""
|
||||
Special tokens for beginning of string and end of string
|
||||
"""
|
||||
|
||||
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
|
||||
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
|
||||
|
||||
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Tool call function with name and arguments
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: dict[str, str]
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""
|
||||
Tool with description, function, and parameters
|
||||
"""
|
||||
|
||||
description: str
|
||||
function: ToolCallFunction
|
||||
parameters: dict[str, str] # .properties
|
||||
|
||||
|
||||
class ToolCallContents(BaseModel):
|
||||
"""
|
||||
Tool call contents with name, arguments, and optional id
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: dict[str, Union[str, int]]
|
||||
id: Optional[str] = None # pylint: disable=invalid-name
|
||||
|
||||
def __str__(self) -> str:
|
||||
data = {"name": self.name, "arguments": self.arguments}
|
||||
if self.id is not None:
|
||||
data["id"] = self.id
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
class ToolResponseContents(BaseModel):
|
||||
"""
|
||||
Tool response contents with name, content, and optional id
|
||||
"""
|
||||
|
||||
name: str
|
||||
content: Union[str, dict[str, Union[str, int, float]]]
|
||||
id: Optional[str] = None # pylint: disable=invalid-name
|
||||
|
||||
def __str__(self) -> str:
|
||||
data = {"name": self.name, "content": self.content}
|
||||
if self.id is not None:
|
||||
data["id"] = self.id
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
class MessageContents(BaseModel):
|
||||
"""
|
||||
Message contents with type, value, metadata, weight, newline, and end of contents
|
||||
"""
|
||||
|
||||
type: Union[str, MessageContentTypes]
|
||||
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
|
||||
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
|
||||
weight: Optional[Union[int, float]] = None
|
||||
has_newline: bool = False
|
||||
eoc: bool = False # end of contents
|
||||
|
||||
def __str__(self) -> str:
|
||||
str_val = str(self.value)
|
||||
if self.has_newline and not str_val.endswith("\n"):
|
||||
str_val += "\n"
|
||||
return str_val
|
||||
|
||||
|
||||
class Messages(BaseModel):
|
||||
"""
|
||||
Messages with role, content, metadata, weight, and chat formatting
|
||||
"""
|
||||
|
||||
role: Union[MessageRoles, str] # allows for arbitrary roles
|
||||
content: List["MessageContents"]
|
||||
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
|
||||
weight: Optional[Union[int, float]] = None
|
||||
is_chat_formatted: bool = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(str(c) for c in self.content)
|
||||
|
||||
def tokenized(
|
||||
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
|
||||
) -> dict[str, List[int]]:
|
||||
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
|
||||
# returns a dictionary mapping w input_ids, attention_mask, and labels
|
||||
input_ids: List[int] = []
|
||||
labels: List[int] = []
|
||||
pending_input_ids: List[int] = []
|
||||
pending_weight = self.weight
|
||||
running_content = ""
|
||||
for _, msg_content in enumerate(self.content):
|
||||
# TODO also handle non-text content types
|
||||
if msg_content.type in [
|
||||
MessageContentTypes.text.value,
|
||||
MessageContentTypes.tool_call.value,
|
||||
MessageContentTypes.tool_response.value,
|
||||
]:
|
||||
running_content += str(msg_content)
|
||||
tok_results = tokenizer(running_content, add_special_tokens=False)
|
||||
tok_input_ids = tok_results["input_ids"]
|
||||
if pending_input_ids:
|
||||
new_pending_inputs = tok_input_ids[
|
||||
len(input_ids) : len(input_ids) + len(pending_input_ids)
|
||||
]
|
||||
if new_pending_inputs != pending_input_ids:
|
||||
# logging.warning("tokenization mismatch from concatenation.")
|
||||
pending_input_ids = new_pending_inputs
|
||||
input_ids.extend(pending_input_ids)
|
||||
if pending_weight:
|
||||
labels.extend(pending_input_ids)
|
||||
else:
|
||||
labels.extend([ignore_index] * len(pending_input_ids))
|
||||
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
|
||||
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
|
||||
input_ids.extend(pending_input_ids)
|
||||
if pending_weight:
|
||||
labels.extend(pending_input_ids)
|
||||
else:
|
||||
labels.extend([ignore_index] * len(pending_input_ids))
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
class Chats(BaseModel):
|
||||
"""
|
||||
top level data structure for chat conversations
|
||||
"""
|
||||
|
||||
conversation: List[Messages]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(str(c) for c in self.conversation)
|
||||
|
||||
def tokenized(
|
||||
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
|
||||
) -> dict[str, List[int]]:
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
labels = []
|
||||
for msg in self.conversation:
|
||||
msg_results = msg.tokenized(tokenizer, ignore_index)
|
||||
input_ids.extend(msg_results["input_ids"])
|
||||
attention_mask.extend(msg_results["attention_mask"])
|
||||
labels.extend(msg_results["labels"])
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
class ChatFormattedChats(Chats):
|
||||
"""
|
||||
Chat formatted chats with formatter and optional train on inputs
|
||||
"""
|
||||
|
||||
formatter: Callable # [[Union[dict, Chats]], Chats]
|
||||
train_on_inputs: bool = False
|
||||
|
||||
def model_post_init(self, __context):
|
||||
for i, msg in enumerate(self.conversation):
|
||||
self.conversation[i] = self.formatter(msg, message_index=i)
|
||||
if self.train_on_inputs:
|
||||
self.conversation[i].weight = 1
|
||||
|
||||
|
||||
class PreferenceChats(BaseModel):
|
||||
"""
|
||||
representation for preference data for chat
|
||||
"""
|
||||
|
||||
prompt: List[Messages]
|
||||
chosen: Messages
|
||||
rejected: Messages
|
||||
0
src/axolotl/core/datasets/__init__.py
Normal file
0
src/axolotl/core/datasets/__init__.py
Normal file
55
src/axolotl/core/datasets/chat.py
Normal file
55
src/axolotl/core/datasets/chat.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
chat dataset module
|
||||
"""
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.core.chat.messages import ChatFormattedChats
|
||||
|
||||
|
||||
class TokenizedChatDataset(Dataset):
|
||||
"""
|
||||
Tokenized chat dataset
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Dataset,
|
||||
model_transform: Union[PreTrainedTokenizer, Callable],
|
||||
*args,
|
||||
message_transform: Optional[Callable] = None,
|
||||
formatter=None,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
def map_fn(ex):
|
||||
if message_transform is not None:
|
||||
ex = message_transform(ex)
|
||||
if formatter is not None:
|
||||
ex = ChatFormattedChats(
|
||||
formatter=formatter,
|
||||
**ex,
|
||||
)
|
||||
else:
|
||||
ex = ChatFormattedChats(
|
||||
**ex,
|
||||
)
|
||||
return ex.tokenized(model_transform)
|
||||
|
||||
process_or_cpu_count: int = (
|
||||
process_count or os.cpu_count() # type: ignore[assignment]
|
||||
)
|
||||
num_proc = min(64, process_or_cpu_count)
|
||||
features = data.features.keys()
|
||||
tokenized_data = data.map(
|
||||
map_fn,
|
||||
num_proc=num_proc,
|
||||
keep_in_memory=keep_in_memory,
|
||||
remove_columns=features,
|
||||
desc="Tokenizing Chats",
|
||||
)
|
||||
super().__init__(tokenized_data.data, *args, **kwargs)
|
||||
0
src/axolotl/core/datasets/transforms/__init__.py
Normal file
0
src/axolotl/core/datasets/transforms/__init__.py
Normal file
150
src/axolotl/core/datasets/transforms/chat_builder.py
Normal file
150
src/axolotl/core/datasets/transforms/chat_builder.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||
"""
|
||||
from typing import Any, Mapping, Union
|
||||
|
||||
|
||||
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
|
||||
train_on_inputs=False,
|
||||
conversations_field: str = "conversations",
|
||||
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
|
||||
message_field_content: Union[str, list[str]] = [
|
||||
"value",
|
||||
"text",
|
||||
"content",
|
||||
], # commonly "content"
|
||||
message_field_training: Union[str, list[str]] = [
|
||||
"train",
|
||||
"weight",
|
||||
], # commonly "weight"
|
||||
):
|
||||
"""Builds a transform that takes a row from the dataset and converts it to a Chat
|
||||
|
||||
Args:
|
||||
train_on_inputs (bool, optional):
|
||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||
Defaults to False.
|
||||
conversations_field (str, optional):
|
||||
The field name of the conversations. Defaults to "conversations".
|
||||
message_field_role (str | list[str], optional):
|
||||
The field name of the role. Defaults to "role".
|
||||
message_field_content (str | list[str], optional):
|
||||
The field name of the message content. Defaults to "content".
|
||||
message_field_training (str | list[str], optional):
|
||||
The field name of the train/weight. Defaults to "weight".
|
||||
|
||||
Returns:
|
||||
Callable:
|
||||
A function that takes a list of conversations and returns a list of messages.
|
||||
"""
|
||||
|
||||
message_field_role = (
|
||||
[message_field_role]
|
||||
if isinstance(message_field_role, str)
|
||||
else message_field_role
|
||||
)
|
||||
message_field_content = (
|
||||
[message_field_content]
|
||||
if isinstance(message_field_content, str)
|
||||
else message_field_content
|
||||
)
|
||||
message_weight_fields = (
|
||||
[message_field_training]
|
||||
if isinstance(message_field_training, str)
|
||||
else message_field_training
|
||||
)
|
||||
|
||||
role_value_mappings = {
|
||||
"system": "system",
|
||||
"user": "user",
|
||||
"human": "user",
|
||||
"assistant": "assistant",
|
||||
"gpt": "assistant",
|
||||
"tool": "tool",
|
||||
"ipython": "ipython",
|
||||
}
|
||||
if train_on_inputs:
|
||||
role_default_weights_mappings = {
|
||||
"system": 1,
|
||||
"user": 1,
|
||||
"assistant": 1,
|
||||
"tool": 1,
|
||||
"ipython": 1,
|
||||
}
|
||||
else:
|
||||
role_default_weights_mappings = {
|
||||
"system": 0,
|
||||
"user": 0,
|
||||
"assistant": 1,
|
||||
"tool": 0,
|
||||
"ipython": 0,
|
||||
}
|
||||
|
||||
def transform_builder(sample: Mapping[str, Any]):
|
||||
if conversations_field not in sample:
|
||||
raise ValueError(f"Field '{conversations_field}' not found in sample.")
|
||||
# if none of the role fields are in the message, raise an error
|
||||
if not any(
|
||||
role in sample[conversations_field][0] for role in message_field_role
|
||||
):
|
||||
raise ValueError("No role field found in message.")
|
||||
role_field = next(
|
||||
role
|
||||
for role in message_field_role
|
||||
if role in sample[conversations_field][0]
|
||||
)
|
||||
if not any(
|
||||
field in sample[conversations_field][0] for field in message_field_content
|
||||
):
|
||||
raise ValueError("No message_content field found in message.")
|
||||
message_content_field = next(
|
||||
field
|
||||
for field in message_field_content
|
||||
if field in sample[conversations_field][0]
|
||||
)
|
||||
if not any(
|
||||
field in sample[conversations_field][0] for field in message_field_training
|
||||
):
|
||||
message_weight_field = None
|
||||
else:
|
||||
message_weight_field = next(
|
||||
field
|
||||
for field in message_weight_fields
|
||||
if field in sample[conversations_field][0]
|
||||
)
|
||||
|
||||
messages = []
|
||||
for message in sample[conversations_field]:
|
||||
role = role_value_mappings[message[role_field]]
|
||||
weight = (
|
||||
int(message[message_weight_field])
|
||||
if message_weight_field
|
||||
else role_default_weights_mappings[role]
|
||||
)
|
||||
|
||||
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
|
||||
if isinstance(message[message_content_field], str):
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": message[message_content_field],
|
||||
}
|
||||
],
|
||||
"weight": weight,
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": message[message_content_field],
|
||||
"weight": weight,
|
||||
}
|
||||
)
|
||||
|
||||
return {"conversation": messages}
|
||||
|
||||
return transform_builder
|
||||
@@ -7,6 +7,7 @@ import abc
|
||||
import gc
|
||||
import importlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -27,7 +28,6 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
@@ -43,12 +43,15 @@ from trl import (
|
||||
KTOTrainer,
|
||||
ORPOConfig,
|
||||
ORPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
@@ -61,7 +64,7 @@ from axolotl.utils.callbacks import (
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
@@ -301,6 +304,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||
"""
|
||||
Reward config for Reward training
|
||||
"""
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
@@ -398,12 +408,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
num_epochs=1,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
super().__init__(*_args, **kwargs)
|
||||
@@ -659,7 +667,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
@@ -667,8 +677,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
@@ -764,7 +784,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
).squeeze(2)
|
||||
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
||||
|
||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
||||
def orpo_compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||
inputs,
|
||||
label_pad_token=-100,
|
||||
@@ -891,6 +917,7 @@ class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
@@ -998,18 +1025,32 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
def tokenize_row(
|
||||
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
||||
self,
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = super().tokenize_row(feature, model=model)
|
||||
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
res = super().tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs)
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
@@ -1039,6 +1080,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
@@ -1099,26 +1148,49 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def get_callbacks(self) -> List[TrainerCallback]:
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from transformers.integrations.integration_utils import MLflowCallback
|
||||
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.extend(
|
||||
[
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||
MLflowCallback,
|
||||
]
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
@@ -1164,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
@@ -1179,6 +1251,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer, self.tokenizer, "mlflow"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "comet_ml"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
@@ -1203,6 +1280,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -1430,11 +1509,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_arguments_kwargs["report_to"] = report_to
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
)
|
||||
if self.cfg.use_wandb:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_arguments_kwargs["run_name"] = None
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
@@ -1523,8 +1607,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
if self.cfg.chat_template:
|
||||
training_arguments_kwargs["chat_template"] = chat_templates(
|
||||
self.cfg.chat_template
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||
self.cfg.chat_template,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
@@ -1537,6 +1622,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
@@ -1580,10 +1668,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
training_args_cls = (
|
||||
AxolotlTrainingArguments
|
||||
if not self.cfg.reward_model
|
||||
else AxolotlRewardConfig
|
||||
)
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
|
||||
@@ -1605,27 +1696,37 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if self.cfg.reward_model:
|
||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
if eval_data_collator := self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
):
|
||||
if not self.cfg.reward_model:
|
||||
trainer_kwargs["eval_data_collator"] = eval_data_collator
|
||||
if not self.cfg.reward_model:
|
||||
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters.keys():
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
else:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||
eval_data_collator=self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
@@ -1659,9 +1760,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
RewardDataCollatorWithPadding,
|
||||
]
|
||||
]
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
if "max_length" in kwargs:
|
||||
kwargs.pop("max_length")
|
||||
elif use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
@@ -1698,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
@@ -1863,7 +1969,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
@@ -1875,11 +1981,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
else:
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
tokenizer=self.tokenizer,
|
||||
callbacks=self.get_callbacks(),
|
||||
**dpo_trainer_kwargs,
|
||||
)
|
||||
@@ -1901,11 +2013,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
|
||||
@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
|
||||
|
||||
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
|
||||
"""
|
||||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import OrderedDict
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
@@ -47,7 +48,7 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg):
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
|
||||
@@ -63,7 +64,7 @@ class BasePlugin:
|
||||
Returns a pydantic model for the plugin's input arguments.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before the model is loaded.
|
||||
|
||||
@@ -74,7 +75,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
|
||||
@@ -86,7 +87,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before LoRA weights are loaded.
|
||||
|
||||
@@ -98,7 +99,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after LoRA weights are loaded.
|
||||
|
||||
@@ -110,7 +111,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
|
||||
@@ -122,7 +123,9 @@ class BasePlugin:
|
||||
object: The created optimizer.
|
||||
"""
|
||||
|
||||
def create_lr_scheduler(self, cfg, trainer, optimizer):
|
||||
def create_lr_scheduler(
|
||||
self, cfg, trainer, optimizer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns a learning rate scheduler.
|
||||
|
||||
@@ -135,7 +138,7 @@ class BasePlugin:
|
||||
object: The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model):
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer before training.
|
||||
|
||||
@@ -146,8 +149,11 @@ class BasePlugin:
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer after training.
|
||||
|
||||
@@ -158,6 +164,30 @@ class BasePlugin:
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> BasePlugin:
|
||||
@@ -204,7 +234,7 @@ class PluginManager:
|
||||
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||
"""
|
||||
|
||||
plugins: List[BasePlugin] = []
|
||||
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
|
||||
|
||||
_instance = None
|
||||
|
||||
@@ -214,7 +244,7 @@ class PluginManager:
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(PluginManager, cls).__new__(cls)
|
||||
cls._instance.plugins: List[BasePlugin] = []
|
||||
cls._instance.plugins = collections.OrderedDict()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
@@ -242,7 +272,7 @@ class PluginManager:
|
||||
"""
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins.append(plugin)
|
||||
self.plugins[plugin_name] = plugin
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
@@ -254,7 +284,7 @@ class PluginManager:
|
||||
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
|
||||
"""
|
||||
input_args = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
input_args_from_plugin = plugin.get_input_args()
|
||||
if input_args_from_plugin is not None:
|
||||
input_args.append(input_args_from_plugin)
|
||||
@@ -270,7 +300,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_model_load(cfg)
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
@@ -284,7 +314,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_model_load(cfg, model)
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
@@ -298,7 +328,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_lora_load(cfg, model)
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
@@ -312,7 +342,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_lora_load(cfg, model)
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
@@ -326,7 +356,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created optimizer, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
optimizer = plugin.create_optimizer(cfg, trainer)
|
||||
if optimizer is not None:
|
||||
return optimizer
|
||||
@@ -344,7 +374,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created learning rate scheduler, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
|
||||
if scheduler is not None:
|
||||
return scheduler
|
||||
@@ -362,7 +392,7 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
||||
return callbacks
|
||||
|
||||
@@ -378,6 +408,20 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||
return callbacks
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
"""
|
||||
Calls the post_train_unload method of all registered plugins.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
@@ -18,20 +18,23 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||
It is designed to be performant, correct, and light-weight.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||
|
||||
|
||||
class LigerPlugin(BasePlugin):
|
||||
"""
|
||||
@@ -42,59 +45,31 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
if cfg.model_config_type == "llama":
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
elif cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "mistral":
|
||||
from liger_kernel.transformers.model.mistral import (
|
||||
lce_forward as mistral_lce_forward,
|
||||
)
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma":
|
||||
from liger_kernel.transformers.model.gemma import (
|
||||
lce_forward as gemma_lce_forward,
|
||||
)
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma.GemmaRMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
kwargs = {}
|
||||
if "rope" in liger_fn_sig.parameters:
|
||||
kwargs["rope"] = cfg.liger_rope
|
||||
if "cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs[
|
||||
"fused_linear_cross_entropy"
|
||||
] = cfg.liger_fused_linear_cross_entropy
|
||||
if "rms_norm" in liger_fn_sig.parameters:
|
||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||
if "layer_norm" in liger_fn_sig.parameters:
|
||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
||||
if "geglu" in liger_fn_sig.parameters:
|
||||
kwargs["geglu"] = cfg.liger_glu_activation
|
||||
elif "swiglu" in liger_fn_sig.parameters:
|
||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
||||
|
||||
apply_liger_fn(**kwargs)
|
||||
elif cfg.model_config_type == "jamba":
|
||||
from transformers.models.jamba import modeling_jamba
|
||||
|
||||
@@ -104,30 +79,12 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "qwen2":
|
||||
from liger_kernel.transformers.model.qwen2 import (
|
||||
lce_forward as qwen2_lce_forward,
|
||||
)
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "deepseek_v2":
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoModelForCausalLM
|
||||
@@ -146,44 +103,9 @@ class LigerPlugin(BasePlugin):
|
||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma2":
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma2.Gemma2RMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
logging.warning(
|
||||
"Fused linear cross entropy is not supported for Gemma 2."
|
||||
)
|
||||
|
||||
elif cfg.model_config_type == "phi3":
|
||||
from liger_kernel.transformers.model.phi3 import (
|
||||
lce_forward as phi3_lce_forward,
|
||||
)
|
||||
from transformers.models.phi3 import modeling_phi3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
||||
|
||||
@@ -15,9 +15,12 @@
|
||||
"""
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger.args")
|
||||
|
||||
|
||||
class LigerArgs(BaseModel):
|
||||
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
if data.get("liger_swiglu") is not None:
|
||||
if data.get("liger_glu_activation") is not None:
|
||||
raise ValueError(
|
||||
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
|
||||
)
|
||||
|
||||
LOG.warning(
|
||||
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
|
||||
"Please use 'liger_glu_activation' instead."
|
||||
)
|
||||
data["liger_glu_activation"] = data.pop("liger_swiglu")
|
||||
return data
|
||||
|
||||
13
src/axolotl/integrations/lm_eval/README.md
Normal file
13
src/axolotl/integrations/lm_eval/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# LM Eval Harness
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.lm_eval.LMEvalPlugin
|
||||
|
||||
lm_eval_tasks:
|
||||
- gsm8k
|
||||
- hellaswag
|
||||
- arc_easy
|
||||
```
|
||||
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
42
src/axolotl/integrations/lm_eval/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Module for the Plugin for LM Eval Harness
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from datetime import datetime
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
|
||||
class LMEvalPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for LM Evaluation Harness integraton with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
tasks = ",".join(cfg.lm_eval_tasks)
|
||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||
output_path = cfg.output_dir
|
||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
subprocess.run( # nosec
|
||||
[
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||
"--tasks",
|
||||
tasks,
|
||||
"--batch_size",
|
||||
str(cfg.lm_eval_batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
15
src/axolotl/integrations/lm_eval/args.py
Normal file
15
src/axolotl/integrations/lm_eval/args.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Module for handling lm eval harness input arguments.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LMEvalArgs(BaseModel):
|
||||
"""
|
||||
Input args for lm eval harness
|
||||
"""
|
||||
|
||||
lm_eval_tasks: List[str] = []
|
||||
lm_eval_batch_size: Optional[int] = 8
|
||||
@@ -22,7 +22,6 @@ from transformers.models.llama.modeling_llama import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
|
||||
@@ -44,7 +43,19 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def is_xformers_available() -> bool:
|
||||
try:
|
||||
import xformers # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def is_xformers_swiglu_available() -> bool:
|
||||
if not is_xformers_available():
|
||||
return False
|
||||
|
||||
from xformers.ops.common import get_xformers_operator
|
||||
|
||||
try:
|
||||
@@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool:
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
if is_xformers_swiglu_available():
|
||||
from axolotl.monkeypatch.xformers_ import FusedMLP
|
||||
else:
|
||||
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
mlp = FusedMLP(
|
||||
@@ -181,49 +197,6 @@ class FusedAttention(LlamaAttention):
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
|
||||
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: list[str],
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: List[str],
|
||||
prune_ratio: float = 0.9,
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
|
||||
@@ -16,26 +16,6 @@ from transformers.models.llama.modeling_llama import (
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||
|
||||
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
"""
|
||||
|
||||
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits = shift_logits,
|
||||
labels = shift_labels,
|
||||
)
|
||||
"""
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
@@ -80,12 +60,6 @@ def get_forward_code() -> str:
|
||||
return forward
|
||||
|
||||
|
||||
def check_cel_is_patchable() -> bool:
|
||||
forward = get_forward_code()
|
||||
forward, _ = detab_code(forward)
|
||||
return ORIGINAL_CEL_CODE in forward
|
||||
|
||||
|
||||
def get_self_attn_code() -> str:
|
||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||
return forward
|
||||
@@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
|
||||
|
||||
|
||||
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
|
||||
|
||||
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int, # pylint: disable=unused-argument
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
|
||||
)
|
||||
return loss
|
||||
|
||||
if model_type == "llama":
|
||||
forward = get_forward_code()
|
||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||
forward, _ = detab_code(forward)
|
||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||
from transformers.loss import loss_utils
|
||||
|
||||
forward = forward.replace(
|
||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||
)
|
||||
forward = forward.replace(
|
||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||
"",
|
||||
)
|
||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||
forward = forward.replace(
|
||||
"def forward(",
|
||||
"def fast_cross_entropy_loss_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.models.llama.modeling_llama):
|
||||
if item in forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||
globals(),
|
||||
)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.models.llama.modeling_llama import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
|
||||
else:
|
||||
raise ValueError("Unsupported model type")
|
||||
|
||||
|
||||
51
src/axolotl/monkeypatch/xformers_/__init__.py
Normal file
51
src/axolotl/monkeypatch/xformers_/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from axolotl.monkeypatch.utils import set_module_name
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
@@ -11,6 +11,10 @@ LOG = logging.getLogger("axolotl.prompt_strategies")
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
try:
|
||||
if strategy == "messages":
|
||||
from .messages import load as messages_load
|
||||
|
||||
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
@@ -31,4 +35,5 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
return None
|
||||
raise exc
|
||||
return None
|
||||
|
||||
10
src/axolotl/prompt_strategies/bradley_terry/README.md
Normal file
10
src/axolotl/prompt_strategies/bradley_terry/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
### example yaml
|
||||
|
||||
```yaml
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
```
|
||||
35
src/axolotl/prompt_strategies/bradley_terry/__init__.py
Normal file
35
src/axolotl/prompt_strategies/bradley_terry/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
|
||||
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
# pylint: disable=duplicate-code
|
||||
try:
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(
|
||||
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
|
||||
)
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||
else:
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
return None
|
||||
102
src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Normal file
102
src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Bradley-Terry model with chat template prompt strategy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplatePrompter,
|
||||
ChatTemplateStrategy,
|
||||
)
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
"""
|
||||
Bradley-Terry reward model pairwise chat template prompt strategy.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
"""
|
||||
|
||||
:param prompt: the actual row of data from the underlying dataset
|
||||
:return:
|
||||
"""
|
||||
|
||||
self.messages = "chosen_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
self.messages = "rejected_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append(
|
||||
{"role": "assistant", "content": prompt["rejected"]}
|
||||
)
|
||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
return {
|
||||
"input_ids_chosen": chosen_tokenized["input_ids"],
|
||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||
"labels_chosen": 1.0,
|
||||
"input_ids_rejected": rejected_tokenized["input_ids"],
|
||||
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
||||
"labels_rejected": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail", None
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1
|
||||
if not cfg.reward_model
|
||||
else cfg.sequence_len,
|
||||
}
|
||||
|
||||
strategy_params = {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", []),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", None),
|
||||
}
|
||||
|
||||
strategy = BTChatTemplateStrategy(
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
27
src/axolotl/prompt_strategies/bradley_terry/llama3.py
Normal file
27
src/axolotl/prompt_strategies/bradley_terry/llama3.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
|
||||
"""
|
||||
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
prompt = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -403,11 +403,16 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
|
||||
@@ -2,15 +2,16 @@
|
||||
DPO prompt strategies for using tokenizer chat templates.
|
||||
"""
|
||||
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||
|
||||
|
||||
def default(
|
||||
cfg, dataset_idx=0, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
ds_cfg = cfg["datasets"][dataset_idx]
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg=cfg, ds_cfg=ds_cfg
|
||||
)
|
||||
field_messages = ds_cfg.get("field_messages", "messages")
|
||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||
@@ -30,6 +31,12 @@ def default(
|
||||
role_map[source] = target
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
chat_template_string = get_chat_template(
|
||||
user_choice=chat_template_choice,
|
||||
jinja_template=chat_template_jinja,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
messages = sample[field_messages]
|
||||
messages = [
|
||||
{
|
||||
@@ -46,28 +53,29 @@ def default(
|
||||
"role": role_map[sample[field_rejected][field_message_role]],
|
||||
"content": sample[field_rejected][field_message_content],
|
||||
}
|
||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||
|
||||
result = {}
|
||||
result["prompt"] = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
result["chosen"] = tokenizer.apply_chat_template(
|
||||
[chosen],
|
||||
[dummy_user_message, chosen],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
)
|
||||
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
|
||||
|
||||
result["rejected"] = tokenizer.apply_chat_template(
|
||||
[rejected],
|
||||
[dummy_user_message, rejected],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
)
|
||||
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||
|
||||
34
src/axolotl/prompt_strategies/messages/__init__.py
Normal file
34
src/axolotl/prompt_strategies/messages/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Module to load message prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg, processor=None):
|
||||
try:
|
||||
strategy = ds_cfg.get("input_transform", "chat")
|
||||
# pylint: disable=duplicate-code
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(
|
||||
f".{strategy}", "axolotl.prompt_strategies.messages"
|
||||
)
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
if "processor" in sig.parameters:
|
||||
load_kwargs["processor"] = processor
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
84
src/axolotl/prompt_strategies/messages/chat.py
Normal file
84
src/axolotl/prompt_strategies/messages/chat.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Chat dataset wrapping strategy for new internal messages representations
|
||||
"""
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from axolotl.core.datasets.chat import TokenizedChatDataset
|
||||
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
|
||||
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||
|
||||
|
||||
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
|
||||
"""
|
||||
Chat dataset wrapping strategy for new internal messages representations
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor,
|
||||
message_transform=None,
|
||||
formatter=None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
"""
|
||||
:param processor: tokenizer or image processor
|
||||
:param kwargs:
|
||||
"""
|
||||
self.processor = processor
|
||||
self.dataset = None
|
||||
self.message_transform = message_transform
|
||||
self.formatter = formatter
|
||||
|
||||
def wrap_dataset(
|
||||
self,
|
||||
dataset,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
self.dataset = TokenizedChatDataset(
|
||||
dataset,
|
||||
message_transform=self.message_transform,
|
||||
model_transform=self.processor,
|
||||
formatter=self.formatter,
|
||||
process_count=process_count,
|
||||
keep_in_memory=keep_in_memory,
|
||||
)
|
||||
return self.dataset
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
field_messages = ds_cfg.get("field_messages")
|
||||
message_field_role = ds_cfg.get("message_field_role")
|
||||
message_field_content = ds_cfg.get("message_field_content")
|
||||
message_field_training = ds_cfg.get("message_field_training")
|
||||
|
||||
builder_kwargs = {}
|
||||
if field_messages:
|
||||
builder_kwargs["conversations_field"] = field_messages
|
||||
if message_field_role:
|
||||
builder_kwargs["message_field_role"] = message_field_role
|
||||
if message_field_content:
|
||||
builder_kwargs["message_field_content"] = message_field_content
|
||||
if message_field_training:
|
||||
builder_kwargs["message_field_training"] = message_field_training
|
||||
|
||||
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
|
||||
format_message = (
|
||||
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
|
||||
)
|
||||
if chat_template == "chatml":
|
||||
from axolotl.core.chat.format.chatml import format_message # noqa F811
|
||||
if chat_template.startswith("llama3"):
|
||||
from axolotl.core.chat.format.llama3x import format_message # noqa F811
|
||||
message_transform: Callable = chat_message_transform_builder(
|
||||
train_on_inputs=ds_cfg.get("train_on_inputs", False),
|
||||
**builder_kwargs,
|
||||
)
|
||||
strategy = ChatMessageDatasetWrappingStrategy(
|
||||
tokenizer, message_transform=message_transform, formatter=format_message
|
||||
)
|
||||
|
||||
return strategy
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
@@ -28,18 +28,13 @@ def load(
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
"""
|
||||
|
||||
chat_template = chat_templates("chatml")
|
||||
if ds_cfg and "chat_template" in ds_cfg:
|
||||
chat_template = ds_cfg["chat_template"]
|
||||
try:
|
||||
chat_template = chat_templates(chat_template)
|
||||
except ValueError:
|
||||
pass
|
||||
tokenizer.chat_template = chat_template
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
tokenizer.chat_template = chat_template_string
|
||||
|
||||
return ORPOTokenizingStrategy(
|
||||
ORPOPrompter(chat_template, tokenizer),
|
||||
ORPOPrompter(chat_template_string, tokenizer),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -248,28 +243,30 @@ class ORPOPrompter(Prompter):
|
||||
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
dataset_parser = ORPODatasetParsingStrategy()
|
||||
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
res = {}
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
res["prompt"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_str_len = len(res["prompt"])
|
||||
res["chosen"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
res["rejected"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_str,
|
||||
chat_template=chat_template_string,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
|
||||
|
||||
@@ -61,6 +61,9 @@ def build_loader(
|
||||
default_conversation: Optional[str] = None,
|
||||
):
|
||||
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
LOG.warning(
|
||||
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
|
||||
)
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
|
||||
@@ -30,6 +30,12 @@ class InvalidDataException(Exception):
|
||||
"""
|
||||
|
||||
|
||||
class DatasetWrappingStrategy(abc.ABC):
|
||||
"""
|
||||
Abstract class for wrapping datasets for Chat Messages
|
||||
"""
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
|
||||
"""
|
||||
Abstract class for tokenizing strategies
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
@@ -97,12 +96,11 @@ def train(
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
# we wait unitl the last possible moment to setup Accelerator
|
||||
Accelerator()
|
||||
model, peft_config = load_model(
|
||||
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||
)
|
||||
model.generation_config.do_sample = True
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
@@ -262,8 +260,10 @@ def train(
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
try:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
except AttributeError:
|
||||
trainer.create_model_card(
|
||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
||||
)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""
|
||||
Basic utils for Axolotl
|
||||
"""
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def is_comet_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
references=[[r] for r in references],
|
||||
predictions=predictions,
|
||||
)
|
||||
scores[metric_name] = score
|
||||
scores["eval_" + metric_name] = score
|
||||
return scores
|
||||
|
||||
def predict_with_generate():
|
||||
@@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
||||
artifact_file="PredictionsVsGroundTruth.json",
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
elif logger == "comet_ml" and is_comet_available():
|
||||
import comet_ml
|
||||
|
||||
experiment = comet_ml.get_running_experiment()
|
||||
if experiment:
|
||||
experiment.log_table(
|
||||
f"{name} - Predictions vs Ground Truth.csv",
|
||||
pd.DataFrame(table_data),
|
||||
)
|
||||
|
||||
if is_main_process():
|
||||
log_table_from_dataloader("Eval", eval_dataloader)
|
||||
|
||||
43
src/axolotl/utils/callbacks/comet_.py
Normal file
43
src/axolotl/utils/callbacks/comet_.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Comet module for trainer callbacks"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import comet_ml
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to comet"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
comet_experiment = comet_ml.start(source="axolotl")
|
||||
comet_experiment.log_other("Created from", "axolotl")
|
||||
comet_experiment.log_asset(
|
||||
self.axolotl_config_path,
|
||||
file_name="axolotl-config",
|
||||
)
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the Comet Experiment under assets."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
|
||||
return control
|
||||
File diff suppressed because one or more lines are too long
@@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||
from transformers.data.data_collator import DataCollatorMixin
|
||||
from transformers.utils import PaddingStrategy
|
||||
@@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [example["images"] for example in examples]
|
||||
images = [
|
||||
Image.open(example["images"])
|
||||
if isinstance(example["images"], str)
|
||||
else example["images"]
|
||||
for example in examples
|
||||
]
|
||||
|
||||
if max_images > 0:
|
||||
images = [img_batch[:max_images] for img_batch in images]
|
||||
|
||||
93
src/axolotl/utils/comet_.py
Normal file
93
src/axolotl/utils/comet_.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Module for wandb utilities"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.comet_")
|
||||
|
||||
COMET_ENV_MAPPING_OVERRIDE = {
|
||||
"comet_mode": "COMET_START_MODE",
|
||||
"comet_online": "COMET_START_ONLINE",
|
||||
}
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
|
||||
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
|
||||
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
|
||||
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
|
||||
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
|
||||
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
|
||||
"auto_log_co2": "COMET_AUTO_LOG_CO2",
|
||||
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
|
||||
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
|
||||
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
|
||||
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
|
||||
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
|
||||
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
|
||||
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
|
||||
"log_code": "COMET_AUTO_LOG_CODE",
|
||||
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
|
||||
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
|
||||
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
|
||||
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
|
||||
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
|
||||
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
|
||||
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
|
||||
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
|
||||
"log_graph": "COMET_AUTO_LOG_GRAPH",
|
||||
"name": "COMET_START_EXPERIMENT_NAME",
|
||||
"offline_directory": "COMET_OFFLINE_DIRECTORY",
|
||||
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
|
||||
"tags": "COMET_START_EXPERIMENT_TAGS",
|
||||
}
|
||||
|
||||
|
||||
def python_value_to_environ_value(python_value):
|
||||
if isinstance(python_value, bool):
|
||||
if python_value is True:
|
||||
return "true"
|
||||
|
||||
return "false"
|
||||
|
||||
if isinstance(python_value, int):
|
||||
return str(python_value)
|
||||
|
||||
if isinstance(python_value, list): # Comet only have one list of string parameter
|
||||
return ",".join(map(str, python_value))
|
||||
|
||||
return python_value
|
||||
|
||||
|
||||
def setup_comet_env_vars(cfg: DictDefault):
|
||||
# TODO, we need to convert Axolotl configuration to environment variables
|
||||
# as Transformers integration are call first and would create an
|
||||
# Experiment first
|
||||
|
||||
for key in cfg.keys():
|
||||
if key.startswith("comet_") and key != "comet_experiment_config":
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value is not None and value != "":
|
||||
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[env_variable_name] = final_value
|
||||
|
||||
if cfg.comet_experiment_config:
|
||||
for key, value in cfg.comet_experiment_config.items():
|
||||
if value is not None and value != "":
|
||||
config_env_variable_name = (
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
|
||||
)
|
||||
|
||||
if config_env_variable_name is None:
|
||||
LOG.warning(
|
||||
f"Unknown Comet Experiment Config name {key}, ignoring it"
|
||||
)
|
||||
continue
|
||||
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[config_env_variable_name] = final_value
|
||||
|
||||
# Enable comet if project name is present
|
||||
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
|
||||
cfg.use_comet = True
|
||||
@@ -228,6 +228,7 @@ def normalize_cfg_datasets(cfg):
|
||||
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].chat_template = cfg.chat_template
|
||||
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||
|
||||
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
|
||||
@@ -8,9 +8,16 @@ import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
StringConstraints,
|
||||
conlist,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
@@ -21,6 +28,38 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
kto = "kto" # pylint: disable=invalid-name
|
||||
simpo = "simpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
jinja = "jinja" # pylint: disable=invalid-name
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
"""configurations that are deprecated"""
|
||||
|
||||
@@ -102,14 +141,22 @@ class SFTDataset(BaseModel):
|
||||
path: Optional[str] = None
|
||||
split: Optional[str] = None
|
||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||
input_transform: Optional[str] = None
|
||||
shards: Optional[int] = None
|
||||
conversation: Optional[str] = None
|
||||
chat_template: Optional[str] = None
|
||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
str,
|
||||
]
|
||||
] = None
|
||||
chat_template_jinja: Optional[str] = None
|
||||
data_files: Optional[Union[str, List[str]]] = None
|
||||
input_format: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
ds_type: Optional[str] = None
|
||||
train_on_split: Optional[str] = None
|
||||
|
||||
field: Optional[str] = None
|
||||
field_human: Optional[str] = None
|
||||
field_model: Optional[str] = None
|
||||
@@ -120,11 +167,31 @@ class SFTDataset(BaseModel):
|
||||
message_field_training_detail: Optional[str] = None
|
||||
roles_to_train: Optional[List[str]] = None
|
||||
train_on_eos: Optional[str] = None
|
||||
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
drop_system_message: Optional[bool] = None
|
||||
|
||||
trust_remote_code: Optional[bool] = False
|
||||
revision: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
# Set chat_template to tokenizer_default if not set
|
||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||
|
||||
# if chat_template is set to jinja, chat_template_jinja is required
|
||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||
"chat_template_jinja"
|
||||
):
|
||||
raise ValueError(
|
||||
"chat_template_jinja is required when chat_template is set to jinja"
|
||||
)
|
||||
|
||||
# If chat_template_jinja is set, set chat_template to jinja
|
||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.jinja
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class UserDefinedDPOType(BaseModel):
|
||||
@@ -146,6 +213,7 @@ class DPODataset(BaseModel):
|
||||
split: Optional[str] = None
|
||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
revision: Optional[str] = None
|
||||
|
||||
|
||||
class UserDefinedKTOType(BaseModel):
|
||||
@@ -167,32 +235,7 @@ class KTODataset(BaseModel):
|
||||
type: Optional[Union[UserDefinedKTOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
trust_remote_code: Optional[bool] = False
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
kto = "kto" # pylint: disable=invalid-name
|
||||
simpo = "simpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
inst = "inst" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
revision: Optional[str] = None
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
@@ -444,6 +487,7 @@ class MLFlowConfig(BaseModel):
|
||||
use_mlflow: Optional[bool] = None
|
||||
mlflow_tracking_uri: Optional[str] = None
|
||||
mlflow_experiment_name: Optional[str] = None
|
||||
mlflow_run_name: Optional[str] = None
|
||||
hf_mlflow_log_artifacts: Optional[bool] = None
|
||||
|
||||
|
||||
@@ -489,6 +533,19 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class CometConfig(BaseModel):
|
||||
"""Comet configuration subset"""
|
||||
|
||||
use_comet: Optional[bool] = None
|
||||
comet_api_key: Optional[str] = None
|
||||
comet_workspace: Optional[str] = None
|
||||
comet_project_name: Optional[str] = None
|
||||
comet_experiment_key: Optional[str] = None
|
||||
comet_mode: Optional[str] = None
|
||||
comet_online: Optional[bool] = None
|
||||
comet_experiment_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
@@ -509,6 +566,7 @@ class AxolotlInputConfig(
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
CometConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
@@ -526,8 +584,10 @@ class AxolotlInputConfig(
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
mean_resizing_embeddings: Optional[bool] = False
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
reward_model: Optional[bool] = None
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
@@ -694,7 +754,13 @@ class AxolotlInputConfig(
|
||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||
low_cpu_mem_usage: Optional[bool] = None
|
||||
|
||||
chat_template: Optional[ChatTemplate] = None
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||
]
|
||||
] = None
|
||||
chat_template_jinja: Optional[str] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
fix_untrained_tokens: Optional[bool] = None
|
||||
@@ -803,6 +869,23 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
# if chat_template is set to jinja, chat_template_jinja is required
|
||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||
"chat_template_jinja"
|
||||
):
|
||||
raise ValueError(
|
||||
"chat_template_jinja is required when chat_template is set to jinja"
|
||||
)
|
||||
|
||||
# If chat_template_jinja is set, set chat_template to jinja
|
||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.jinja
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_wo_flash(cls, data):
|
||||
@@ -833,6 +916,17 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def hint_reward_model_pad(cls, data):
|
||||
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using reward_model"
|
||||
)
|
||||
if data.get("pad_to_sequence_len") is None:
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gas_bsz(cls, data):
|
||||
@@ -966,6 +1060,26 @@ class AxolotlInputConfig(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
data.get("evals_per_epoch") or data.get("eval_steps")
|
||||
):
|
||||
raise ValueError(
|
||||
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_test_datasets_bench(cls, data):
|
||||
if (
|
||||
data.get("do_bench_eval")
|
||||
and not data.get("test_datasets")
|
||||
and not data.get("val_set_size")
|
||||
):
|
||||
LOG.warning(
|
||||
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
||||
)
|
||||
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_cfg["path"],
|
||||
split=ds_cfg["split"],
|
||||
revision=ds_cfg.get("revision", None),
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
@@ -19,10 +21,12 @@ from transformers import PreTrainedTokenizerBase
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
DatasetWrappingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
@@ -51,6 +55,28 @@ from axolotl.utils.trainer import (
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except (
|
||||
requests.exceptions.ReadTimeout,
|
||||
requests.exceptions.ConnectionError,
|
||||
) as exc:
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(delay)
|
||||
else:
|
||||
raise exc
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
@@ -242,6 +268,7 @@ def load_tokenized_prepared_datasets(
|
||||
name=config_dataset.name,
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
@@ -346,6 +373,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
@@ -380,6 +408,7 @@ def load_tokenized_prepared_datasets(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
@@ -389,6 +418,7 @@ def load_tokenized_prepared_datasets(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -433,8 +463,8 @@ def load_tokenized_prepared_datasets(
|
||||
config_dataset=config_dataset,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
dataset=ds,
|
||||
d_base_type=d_base_type,
|
||||
dataset=ds,
|
||||
d_prompt_style=d_prompt_style,
|
||||
processor=processor,
|
||||
)
|
||||
@@ -454,7 +484,7 @@ def load_tokenized_prepared_datasets(
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged datasets")
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
if cfg.sample_packing and not cfg.skip_prepare_dataset:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
@@ -569,7 +599,7 @@ def get_dataset_wrapper(
|
||||
d_base_type,
|
||||
dataset,
|
||||
d_prompt_style=None,
|
||||
processor=None,
|
||||
processor=None, # pylint: disable=unused-argument
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
@@ -604,8 +634,10 @@ def get_dataset_wrapper(
|
||||
)
|
||||
elif cfg.skip_prepare_dataset:
|
||||
dataset_wrapper = dataset
|
||||
elif ds_strategy := load(
|
||||
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
|
||||
elif ds_strategy := config_dataset.type.startswith(
|
||||
"bradley_terry"
|
||||
) and bradley_terry_load(
|
||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||
):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
@@ -613,6 +645,18 @@ def get_dataset_wrapper(
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
)
|
||||
elif ds_strategy := load(
|
||||
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
|
||||
):
|
||||
if isinstance(ds_strategy, DatasetWrappingStrategy):
|
||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||
else:
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
)
|
||||
elif d_base_type == "alpaca":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
|
||||
@@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault):
|
||||
# Enable mlflow if experiment name is present
|
||||
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
|
||||
cfg.use_mlflow = True
|
||||
|
||||
# Enable logging hf artifacts in mlflow if value is truthy
|
||||
if cfg.hf_mlflow_log_artifacts is True:
|
||||
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -133,6 +133,8 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
self.len_across_ranks = None
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -195,15 +197,14 @@ class MultipackBatchSampler(BatchSampler):
|
||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||
return math.floor(0.998 * min(estimates))
|
||||
|
||||
min_len_batches = reduce_and_broadcast(
|
||||
lambda: num,
|
||||
calc_min_len,
|
||||
)
|
||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
len_batches = self.num_batches()
|
||||
return self.gather_len_batches(len_batches)
|
||||
if not self.len_across_ranks:
|
||||
len_batches = self.num_batches()
|
||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||
return self.len_across_ranks
|
||||
|
||||
def _len_est(self):
|
||||
efficiency = (
|
||||
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import set_caching_enabled
|
||||
from datasets import disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
@contextmanager
|
||||
def disable_datasets_caching():
|
||||
try:
|
||||
set_caching_enabled(False)
|
||||
disable_caching()
|
||||
yield
|
||||
finally:
|
||||
set_caching_enabled(True)
|
||||
enable_caching()
|
||||
|
||||
|
||||
def add_position_ids(sample):
|
||||
@@ -306,7 +306,11 @@ def process_pretraining_datasets_for_packing(
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
|
||||
if (
|
||||
not cfg.total_num_tokens
|
||||
and not cfg.skip_prepare_dataset
|
||||
and not cfg.reward_model
|
||||
):
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
@@ -323,6 +327,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
not skip_estimates
|
||||
and not cfg.total_supervised_tokens
|
||||
and not cfg.skip_prepare_dataset
|
||||
and not cfg.reward_model
|
||||
):
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
|
||||
0
tests/core/chat/__init__.py
Normal file
0
tests/core/chat/__init__.py
Normal file
0
tests/core/chat/format/__init__.py
Normal file
0
tests/core/chat/format/__init__.py
Normal file
197
tests/core/chat/test_messages.py
Normal file
197
tests/core/chat/test_messages.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Tests for the chat messages module
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers import AddedToken, AutoTokenizer
|
||||
|
||||
from axolotl.core.chat.format.chatml import format_message
|
||||
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="llama_tokenizer")
|
||||
def llama_tokenizer_fixture():
|
||||
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
||||
def llama_tokenizer_w_chatml(llama_tokenizer):
|
||||
llama_tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
llama_tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
|
||||
return llama_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="chat_msgs")
|
||||
def chat_msgs_fixture():
|
||||
return {
|
||||
"conversation": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "value": "You are a helpful assistant."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "What is today's stock price of Apple?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_call",
|
||||
"value": {
|
||||
"name": "get_date",
|
||||
"arguments": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"value": {
|
||||
"name": "get_stock_price",
|
||||
"arguments": {"symbol": "AAPL"},
|
||||
},
|
||||
},
|
||||
],
|
||||
"weight": 1,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_response",
|
||||
"value": {
|
||||
"name": "get_date",
|
||||
"content": {"date": "2024-09-09"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "tool_response",
|
||||
"value": {
|
||||
"name": "get_stock_price",
|
||||
"content": {"symbol": "AAPL", "price": 123.45},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": "The stock price of Apple is $123.45.\n",
|
||||
"weight": 0,
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"value": "The stock price of Apple on September 9, 2024 is $123.45.",
|
||||
},
|
||||
],
|
||||
"weight": 1,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class TestMessagesCase:
|
||||
"""
|
||||
Test cases for the chat messages module
|
||||
"""
|
||||
|
||||
def test_tool_call_stringify(self, chat_msgs):
|
||||
chat_msgs_as_obj = Chats(**chat_msgs)
|
||||
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
|
||||
chat_msgs_as_obj.conversation[2].content[1].value
|
||||
)
|
||||
|
||||
def test_chatml_formatted_wrapper(self, chat_msgs):
|
||||
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
|
||||
target_chatml = """<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
What is today's stock price of Apple?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_date", "arguments": {}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
|
||||
</tool_call>
|
||||
<|im_end|>
|
||||
<|im_start|>tool
|
||||
<tool_response>
|
||||
{"name": "get_date", "content": {"date": "2024-09-09"}}
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
The stock price of Apple is $123.45.
|
||||
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
|
||||
assert target_chatml == str(chat_msg_formatted)
|
||||
|
||||
def test_chatml_formatting_tool_call(self, chat_msgs):
|
||||
chat_msgs_as_obj = Chats(**chat_msgs)
|
||||
target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
|
||||
assert target_chatml_turn2 == str(
|
||||
format_message(chat_msgs_as_obj.conversation[2])
|
||||
)
|
||||
|
||||
def test_train_labels(self, chatml_tokenizer, chat_msgs):
|
||||
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
|
||||
tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
|
||||
# fmt: off
|
||||
target_labels = [
|
||||
-100, -100, -100, # role
|
||||
27, 14506, 13735, 397, 5018, 609, 794,
|
||||
330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
|
||||
14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
|
||||
330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
|
||||
794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
|
||||
128256, # <|im_end|>
|
||||
-100 # trailing newline
|
||||
]
|
||||
# fmt: on
|
||||
assert tokenized["labels"] == target_labels
|
||||
|
||||
def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
|
||||
# also test if indivudal contents are set not to train
|
||||
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
|
||||
tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
|
||||
# fmt: off
|
||||
target_labels = [
|
||||
-100, -100, -100, # role
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
|
||||
27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
|
||||
315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
|
||||
5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
|
||||
8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
|
||||
4513, 13, 1774, 13,
|
||||
128256, # <|im_end|>
|
||||
-100, # trailing newline
|
||||
]
|
||||
# fmt: on
|
||||
assert tokenized["labels"] == target_labels
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
155
tests/e2e/multigpu/test_eval.py
Normal file
155
tests/e2e/multigpu/test_eval.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
E2E tests for multigpu eval
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
class TestMultiGPUEval(unittest.TestCase):
|
||||
"""
|
||||
Test case for MultiGPU Eval Sample Packing
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_eval_sample_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"strict": False,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"loss_watchdog_threshold": 5.0,
|
||||
"loss_watchdog_patience": 3,
|
||||
"bf16": "auto",
|
||||
"warmup_steps": 1,
|
||||
"evals_per_epoch": 2,
|
||||
"eval_max_new_tokens": 128,
|
||||
"saves_per_epoch": 1,
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_eval(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"strict": False,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"loss_watchdog_threshold": 5.0,
|
||||
"loss_watchdog_patience": 3,
|
||||
"bf16": "auto",
|
||||
"warmup_steps": 1,
|
||||
"evals_per_epoch": 2,
|
||||
"eval_max_new_tokens": 128,
|
||||
"saves_per_epoch": 1,
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
@@ -14,11 +14,13 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import is_hopper, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_model():
|
||||
@@ -57,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -114,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 50,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -142,6 +144,146 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
||||
@with_temp_dir
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fsdp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -163,7 +305,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -229,7 +371,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -271,7 +413,6 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.skip("disabled due to upstream issue")
|
||||
@with_temp_dir
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -280,6 +421,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"adapter": "qlora",
|
||||
"mean_resizing_embeddings": True,
|
||||
"load_in_4bit": True,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
@@ -295,7 +437,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|end_of_text|>",
|
||||
"pad_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -305,7 +447,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -346,3 +488,115 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ds_zero3_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import require_torch_2_1_1, with_temp_dir
|
||||
from ..utils import require_torch_2_3_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
Test case for Llama models using 4d attention with multipack
|
||||
"""
|
||||
|
||||
@require_torch_2_1_1
|
||||
@require_torch_2_3_1
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -1,22 +1,12 @@
|
||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import (
|
||||
check_cel_is_patchable,
|
||||
check_self_attn_is_patchable,
|
||||
)
|
||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||
|
||||
|
||||
class TestUnslothIntegration(unittest.TestCase):
|
||||
"""Unsloth monkeypatch integration tests."""
|
||||
|
||||
def test_is_cel_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
check_cel_is_patchable(),
|
||||
"HF transformers loss code has changed and isn't patchable",
|
||||
)
|
||||
|
||||
def test_is_self_attn_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
|
||||
95
tests/e2e/test_load_model.py
Normal file
95
tests/e2e/test_load_model.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Module for testing ModelLoader."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="temp_dir")
|
||||
def fixture_temp_dir():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
class TestLoadModelUtils:
|
||||
"""
|
||||
Testing module testing ModelLoader.
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
# load config
|
||||
self.cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"tokenizer_config": "JackFram/llama-68m",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": False,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
|
||||
@pytest.mark.parametrize(
|
||||
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
|
||||
)
|
||||
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
|
||||
def test_convert_embedding_modules_dtype(
|
||||
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
):
|
||||
self.cfg.output_dir = temp_dir
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
||||
self.model_loader.model, _ = load_model(
|
||||
self.cfg,
|
||||
self.model_loader.tokenizer,
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
self.model_loader.convert_embedding_modules_dtype(
|
||||
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
)
|
||||
for name, module in self.model_loader.model.named_modules():
|
||||
if (
|
||||
"norm" in name
|
||||
or (before_kbit_train_or_finetune and name.endswith(".gate"))
|
||||
or (
|
||||
any(m in name for m in embedding_modules)
|
||||
and hasattr(module, "weight")
|
||||
)
|
||||
):
|
||||
for _, param in module.named_parameters():
|
||||
assert param.dtype == dist_dtype
|
||||
74
tests/e2e/test_packing_loss.py
Normal file
74
tests/e2e/test_packing_loss.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
E2E tests for packed training
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from tbparse import SummaryReader
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import most_recent_subdir, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPackedLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Packed training of llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_loss_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
74
tests/e2e/test_reward_model_llama.py
Normal file
74
tests/e2e/test_reward_model_llama.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
E2E tests for reward model lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestRewardModelLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama reward models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_rm_fft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"model_type": "AutoModelForSequenceClassification",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"chat_template": "alpaca",
|
||||
"reward_model": True,
|
||||
"sequence_len": 1024,
|
||||
"pad_to_sequence_len": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||
"type": "bradley_terry.chat_template",
|
||||
},
|
||||
],
|
||||
"remove_unused_columns": False,
|
||||
"max_steps": 10,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"gradient_checkpointing": True,
|
||||
"warmup_ratio": 0.1,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
@@ -9,6 +9,8 @@ from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
@@ -35,13 +37,18 @@ def most_recent_subdir(path):
|
||||
return subdir
|
||||
|
||||
|
||||
def require_torch_2_1_1(test_case):
|
||||
def require_torch_2_3_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.1.1
|
||||
Decorator marking a test that requires torch >= 2.3.1
|
||||
"""
|
||||
|
||||
def is_min_2_1_1():
|
||||
def is_min_2_3_1():
|
||||
torch_version = version("torch")
|
||||
return torch_version >= "2.1.1"
|
||||
return torch_version >= "2.3.1"
|
||||
|
||||
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
|
||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
||||
|
||||
|
||||
def is_hopper():
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
return compute_capability == (9, 0)
|
||||
|
||||
0
tests/prompt_strategies/messages/__init__.py
Normal file
0
tests/prompt_strategies/messages/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user