Compare commits
46 Commits
upgrade-li
...
cj_tokeniz
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
28e7e444ee | ||
|
|
207e7627f9 | ||
|
|
7eb62ae5a9 | ||
|
|
95805cf850 | ||
|
|
4aafb7e600 | ||
|
|
17bc4c8b36 | ||
|
|
d101cfc125 | ||
|
|
e5cd55cff9 | ||
|
|
24aa6b15a0 | ||
|
|
9dfc5fa8b8 | ||
|
|
0c3255288f | ||
|
|
82b5dc9328 | ||
|
|
ec57918fcd | ||
|
|
dd87d8c438 | ||
|
|
ef942b6efc | ||
|
|
3c6a6c61be | ||
|
|
7b4b665e99 | ||
|
|
21326e4ef3 | ||
|
|
de23dab4fc | ||
|
|
e3efa29cf5 | ||
|
|
2038255052 | ||
|
|
dab2590e4d | ||
|
|
e5162b7a41 | ||
|
|
b6321d2220 | ||
|
|
6b3cdfdb8e | ||
|
|
203ae28704 | ||
|
|
ed3a33c9fb | ||
|
|
f61e2fc7dc | ||
|
|
b8056d04d9 | ||
|
|
88658c0570 | ||
|
|
260ca97f2c | ||
|
|
b1bb2accb9 | ||
|
|
efeaa00bb4 | ||
|
|
8a84408fc7 | ||
|
|
4805f3ca0a | ||
|
|
8ee30f5954 | ||
|
|
6ef76f1ace | ||
|
|
2e758aed6f | ||
|
|
21a2302538 | ||
|
|
89f382a13a | ||
|
|
eb188acbd4 | ||
|
|
34ea51dcf3 | ||
|
|
fd7538dca7 | ||
|
|
99b3bc7fbd | ||
|
|
4e38cea6b8 | ||
|
|
5edaad5b8b |
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -36,12 +36,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "124"
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -29,11 +29,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -91,11 +86,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
13
.github/workflows/multi-gpu-e2e.yml
vendored
13
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,17 +21,10 @@ jobs:
|
|||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
- cuda: 124
|
- cuda: 121
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.1.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
|
||||||
num_gpus: 2
|
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.0
|
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -28,11 +28,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -90,11 +85,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging
|
pip3 install wheel packaging
|
||||||
pip3 install -e .
|
pip3 install -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Extract tag name
|
- name: Extract tag name
|
||||||
id: tag
|
id: tag
|
||||||
|
|||||||
15
.github/workflows/tests-nightly.yml
vendored
15
.github/workflows/tests-nightly.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
pytorch_version: ["2.3.1", "2.4.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -47,14 +47,13 @@ jobs:
|
|||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging
|
pip3 install --upgrade packaging
|
||||||
pip3 install -U -e .
|
pip3 install -U -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
@@ -82,17 +81,17 @@ jobs:
|
|||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: mamba-ssm
|
axolotl_extras: mamba-ssm
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 121
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.1.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.3.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras: mamba-ssm
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.0
|
pytorch: 2.4.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
76
.github/workflows/tests.yml
vendored
76
.github/workflows/tests.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
pytorch_version: ["2.3.1", "2.4.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -49,20 +49,16 @@ jobs:
|
|||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|
||||||
- name: upgrade pip
|
|
||||||
run: |
|
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }}
|
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
pip3 install --upgrade pip
|
||||||
|
pip3 install --upgrade packaging
|
||||||
pip3 install -U -e .
|
pip3 install -U -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
@@ -72,17 +68,29 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests-1st:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 60
|
||||||
needs: [pre-commit, pytest]
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.1
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.3.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras: mamba-ssm
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras: mamba-ssm
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -111,49 +119,3 @@ jobs:
|
|||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
|
||||||
runs-on: [self-hosted, modal]
|
|
||||||
timeout-minutes: 90
|
|
||||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
|
||||||
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.1
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.3.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras: mamba-ssm
|
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.0
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
- name: Install Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: "3.10"
|
|
||||||
- name: Install Modal
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install modal==0.63.64 jinja2
|
|
||||||
- name: Update env vars
|
|
||||||
run: |
|
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
|
||||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
|
||||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
|
||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
|
||||||
- name: Run tests job on Modal
|
|
||||||
run: |
|
|
||||||
modal run cicd.tests
|
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ Features:
|
|||||||
|
|
||||||
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
||||||
|
|
||||||
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
|
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||||
@@ -562,8 +562,7 @@ plugins:
|
|||||||
- axolotl.integrations.liger.LigerPlugin
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
liger_rope: true
|
liger_rope: true
|
||||||
liger_rms_norm: true
|
liger_rms_norm: true
|
||||||
liger_glu_activation: true
|
liger_swiglu: true
|
||||||
liger_layer_norm: true
|
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
git checkout FETCH_HEAD
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
RUN pip install -r requirements-tests.txt
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
@stub.function(
|
@stub.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=45 * 60,
|
||||||
cpu=8.0,
|
cpu=8.0,
|
||||||
memory=131072 * N_GPUS,
|
memory=131072 * N_GPUS,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
@stub.function(
|
@stub.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=45 * 60,
|
||||||
cpu=8.0,
|
cpu=8.0,
|
||||||
memory=131072,
|
memory=131072,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,6 +14,15 @@
|
|||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
"gradient_clipping": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
|
|||||||
@@ -24,6 +24,15 @@
|
|||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
"gradient_clipping": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
|
|||||||
@@ -20,6 +20,15 @@
|
|||||||
"bf16": {
|
"bf16": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"gradient_clipping": "auto",
|
"gradient_clipping": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ load_in_8bit: true
|
|||||||
load_in_4bit: false
|
load_in_4bit: false
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: philschmid/guanaco-sharegpt-style
|
||||||
type: chat_template
|
type: sharegpt
|
||||||
shards: 10
|
shards: 10
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: temp_debug/axolotl_outputs/model
|
output_dir: temp_debug/axolotl_outputs/model
|
||||||
@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
|||||||
|
|
||||||
### Background
|
### Background
|
||||||
|
|
||||||
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
|
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
|
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
|
||||||
type: chat_template
|
type: sharegpt
|
||||||
```
|
```
|
||||||
|
|
||||||
>[!Important]
|
>[!Important]
|
||||||
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
|
|||||||
|
|
||||||
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
|
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
|
||||||
|
|
||||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||||
|
|
||||||
```jsonc
|
```jsonc
|
||||||
// .vscode/launch.json
|
// .vscode/launch.json
|
||||||
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
|||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Debug axolotl prompt - chat_template",
|
"name": "Debug axolotl prompt - sharegpt",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
"module": "accelerate.commands.launch",
|
"module": "accelerate.commands.launch",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"args": [
|
"args": [
|
||||||
"-m", "axolotl.cli.train", "dev_chat_template.yml",
|
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
||||||
// The flags below simplify debugging by overriding the axolotl config
|
// The flags below simplify debugging by overriding the axolotl config
|
||||||
// with the debugging tips above. Modify as needed.
|
// with the debugging tips above. Modify as needed.
|
||||||
"--dataset_processes=1", // limits data preprocessing to one process
|
"--dataset_processes=1", // limits data preprocessing to one process
|
||||||
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
|||||||
</div>
|
</div>
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
|
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
|
||||||
|
|
||||||
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
|
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
|
||||||
|
|||||||
@@ -9,17 +9,14 @@ strict: false
|
|||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.liger.LigerPlugin
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
liger_rms_norm: true
|
liger_rms_norm: true
|
||||||
liger_glu_activation: true
|
liger_swiglu: true
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
chat_template: deepseek_v2
|
chat_template: deepseek_v2
|
||||||
datasets:
|
datasets:
|
||||||
- path: mlabonne/FineTome-100k
|
- path: mlabonne/FineTome-100k
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:20%]
|
split: train
|
||||||
field_messages: conversations
|
|
||||||
message_field_role: from
|
|
||||||
message_field_content: value
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
|
|||||||
@@ -11,11 +11,8 @@ chat_template: gemma
|
|||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
chat_template: gemma
|
||||||
drop_system_message: true
|
drop_system_message: true
|
||||||
field_messages: conversations
|
|
||||||
message_field_role: from
|
|
||||||
message_field_content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
|||||||
@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
|
|||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
strict: false
|
strict: false
|
||||||
use_tensorboard: true
|
use_tensorboard: true
|
||||||
chat_template: jamba
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
chat_template: jamba
|
||||||
drop_system_message: true
|
drop_system_message: true
|
||||||
field_messages: conversations
|
|
||||||
message_field_role: from
|
|
||||||
message_field_content: value
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: jamba-large-fsdp-qlora-ft
|
output_dir: jamba-large-fsdp-qlora-ft
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ plugins:
|
|||||||
- axolotl.integrations.liger.LigerPlugin
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
liger_rope: true
|
liger_rope: true
|
||||||
liger_rms_norm: true
|
liger_rms_norm: true
|
||||||
liger_glu_activation: true
|
liger_swiglu: true
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
strict: false
|
strict: false
|
||||||
@@ -14,10 +14,6 @@ datasets:
|
|||||||
- path: mlabonne/FineTome-100k
|
- path: mlabonne/FineTome-100k
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:20%]
|
split: train[:20%]
|
||||||
field_messages: conversations
|
|
||||||
message_field_role: from
|
|
||||||
message_field_content: value
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ rl: dpo
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||||
type: chat_template.default
|
type: chat_template.default
|
||||||
|
chat_template: llama3
|
||||||
field_messages: conversation
|
field_messages: conversation
|
||||||
field_chosen: chosen
|
field_chosen: chosen
|
||||||
field_rejected: rejected
|
field_rejected: rejected
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ chat_template: llama3
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
chat_template: llama3
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
message_field_role: role
|
message_field_role: role
|
||||||
message_field_content: content
|
message_field_content: content
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-3.2-1B
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
pad_token: "<|end_of_text|>"
|
|
||||||
@@ -10,6 +10,7 @@ chat_template: phi_3
|
|||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
chat_template: phi_3
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
message_field_role: role
|
message_field_role: role
|
||||||
message_field_content: content
|
message_field_content: content
|
||||||
|
|||||||
@@ -2,4 +2,3 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
tbparse
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.13.2
|
||||||
transformers==4.46.2
|
transformers==4.45.2
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.1.0
|
accelerate==1.0.1
|
||||||
datasets==3.0.1
|
datasets==3.0.1
|
||||||
deepspeed==0.15.3
|
deepspeed==0.14.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
@@ -16,7 +16,7 @@ flash-attn==2.6.3
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers>=0.0.23.post1
|
xformers==0.0.28.post1
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -34,7 +34,7 @@ tensorboard
|
|||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.4.0
|
liger-kernel==0.3.0
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
|
|||||||
gcsfs>=2024.5.0
|
gcsfs>=2024.5.0
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
|
trl==0.9.6
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
|||||||
12
setup.py
12
setup.py
@@ -31,8 +31,6 @@ def parse_requirements():
|
|||||||
try:
|
try:
|
||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
|
||||||
|
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# don't install xformers on MacOS
|
# don't install xformers on MacOS
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
@@ -52,16 +50,10 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 5):
|
if (major, minor) >= (2, 4):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
|
||||||
elif (major, minor) >= (2, 4):
|
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.27")
|
_install_requires.append("xformers>=0.0.27")
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers==0.0.28.post1")
|
|
||||||
elif (major, minor) >= (2, 3):
|
elif (major, minor) >= (2, 3):
|
||||||
_install_requires.pop(_install_requires.index(torchao_version))
|
_install_requires.pop(_install_requires.index(torchao_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -81,6 +73,7 @@ def parse_requirements():
|
|||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
|
|
||||||
@@ -109,7 +102,6 @@ setup(
|
|||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
"mamba-ssm==1.2.0.post1",
|
"mamba-ssm==1.2.0.post1",
|
||||||
"causal_conv1d",
|
|
||||||
],
|
],
|
||||||
"auto-gptq": [
|
"auto-gptq": [
|
||||||
"auto-gptq==0.5.1",
|
"auto-gptq==0.5.1",
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
chat_template_str = get_chat_template(cfg.chat_template)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
@@ -462,12 +462,7 @@ def load_datasets(
|
|||||||
processor=processor,
|
processor=processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if cli_args.debug or cfg.debug:
|
||||||
cli_args.debug
|
|
||||||
or cfg.debug
|
|
||||||
or cli_args.debug_text_only
|
|
||||||
or int(cli_args.debug_num_examples) > 0
|
|
||||||
):
|
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
train_dataset.select(
|
train_dataset.select(
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class TrainerCliArgs:
|
|||||||
|
|
||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=0)
|
debug_num_examples: int = field(default=5)
|
||||||
inference: bool = field(default=False)
|
inference: bool = field(default=False)
|
||||||
merge_lora: bool = field(default=False)
|
merge_lora: bool = field(default=False)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import abc
|
|||||||
import gc
|
import gc
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -28,6 +27,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
|
PreTrainedModel,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -48,7 +48,6 @@ from trl import (
|
|||||||
)
|
)
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
@@ -667,9 +666,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
|
||||||
):
|
|
||||||
# use one's weighted cross entropy loss calc
|
# use one's weighted cross entropy loss calc
|
||||||
# if self.args.sample_packing:
|
# if self.args.sample_packing:
|
||||||
# labels = inputs.pop("labels")
|
# labels = inputs.pop("labels")
|
||||||
@@ -677,18 +674,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
model,
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
inputs,
|
|
||||||
return_outputs=return_outputs,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
)
|
|
||||||
return super().compute_loss(
|
|
||||||
model,
|
|
||||||
inputs,
|
|
||||||
return_outputs=return_outputs,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||||
@@ -784,13 +771,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
).squeeze(2)
|
).squeeze(2)
|
||||||
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
||||||
|
|
||||||
def orpo_compute_loss(
|
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
||||||
self,
|
|
||||||
model,
|
|
||||||
inputs,
|
|
||||||
return_outputs=False,
|
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||||
inputs,
|
inputs,
|
||||||
label_pad_token=-100,
|
label_pad_token=-100,
|
||||||
@@ -896,13 +877,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
|
|
||||||
def _save_checkpoint(self, model, trial, **kwargs):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# make sure the checkpoint dir exists, since trainer is flakey
|
# make sure the checkpoint dir exists, since trainer is flakey
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
run_dir = self._get_output_dir(trial=trial)
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
@@ -917,7 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
|
|||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
|
||||||
):
|
):
|
||||||
input_ids = inputs.pop("input_ids")
|
input_ids = inputs.pop("input_ids")
|
||||||
lm_logits = model(input_ids).logits
|
lm_logits = model(input_ids).logits
|
||||||
@@ -1025,32 +1005,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
self,
|
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
res = super().tokenize_row(
|
res = super().tokenize_row(feature, model=model)
|
||||||
features,
|
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
)
|
|
||||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
for key in res.keys():
|
||||||
res[key] = res[key][1:]
|
res[key] = res[key][1:]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def training_step(
|
def training_step(
|
||||||
self,
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
|
||||||
model: nn.Module,
|
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
|
||||||
num_items_in_batch=None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
loss: torch.Tensor = super().training_step(model, inputs)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
@@ -1148,28 +1114,17 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
def get_callbacks(self) -> List[TrainerCallback]:
|
def get_callbacks(self) -> List[TrainerCallback]:
|
||||||
callbacks = []
|
callbacks = []
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
callbacks.extend(
|
|
||||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
if self.cfg.use_mlflow and is_mlflow_available():
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
from transformers.integrations.integration_utils import MLflowCallback
|
|
||||||
|
|
||||||
from axolotl.utils.callbacks.mlflow_ import (
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
)
|
)
|
||||||
|
|
||||||
callbacks.extend(
|
callbacks.append(
|
||||||
[
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
|
||||||
MLflowCallback,
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
if self.cfg.use_comet and is_comet_available():
|
if self.cfg.use_comet and is_comet_available():
|
||||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||||
@@ -1180,17 +1135,11 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
"""
|
"""
|
||||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||||
"""
|
"""
|
||||||
callbacks = []
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
callbacks.extend(
|
|
||||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
|
||||||
)
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||||
# TODO
|
# TODO
|
||||||
@@ -1236,7 +1185,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = []
|
||||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||||
LogPredictionCallback = log_prediction_callback_factory(
|
LogPredictionCallback = log_prediction_callback_factory(
|
||||||
trainer, self.tokenizer, "wandb"
|
trainer, self.tokenizer, "wandb"
|
||||||
@@ -1608,8 +1557,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||||
if self.cfg.chat_template:
|
if self.cfg.chat_template:
|
||||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||||
self.cfg.chat_template,
|
self.cfg.chat_template
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
@@ -1714,17 +1662,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
)
|
)
|
||||||
sig = inspect.signature(trainer_cls)
|
|
||||||
if "processing_class" in sig.parameters.keys():
|
|
||||||
trainer_kwargs["processing_class"] = self.tokenizer
|
|
||||||
else:
|
|
||||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
|
||||||
|
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
@@ -1765,8 +1708,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
collator = RewardDataCollatorWithPadding
|
collator = RewardDataCollatorWithPadding
|
||||||
if "max_length" in kwargs:
|
|
||||||
kwargs.pop("max_length")
|
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
@@ -1804,7 +1745,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = []
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def build_training_arguments(self, total_num_steps):
|
def build_training_arguments(self, total_num_steps):
|
||||||
@@ -1969,7 +1910,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
dpo_trainer_kwargs["max_target_length"] = None
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
@@ -1981,17 +1922,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
|
||||||
if "processing_class" in sig.parameters.keys():
|
|
||||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
|
||||||
else:
|
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
|
||||||
|
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**dpo_trainer_kwargs,
|
**dpo_trainer_kwargs,
|
||||||
)
|
)
|
||||||
@@ -2013,11 +1948,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = []
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = []
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
|
|||||||
|
|
||||||
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
|
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
|
||||||
"""
|
"""
|
||||||
import collections
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from typing import OrderedDict
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
@@ -48,7 +47,7 @@ class BasePlugin:
|
|||||||
Initializes the BasePlugin.
|
Initializes the BasePlugin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register(self, cfg): # pylint: disable=unused-argument
|
def register(self, cfg):
|
||||||
"""
|
"""
|
||||||
Registers the plugin with the given configuration.
|
Registers the plugin with the given configuration.
|
||||||
|
|
||||||
@@ -64,7 +63,7 @@ class BasePlugin:
|
|||||||
Returns a pydantic model for the plugin's input arguments.
|
Returns a pydantic model for the plugin's input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
def pre_model_load(self, cfg):
|
||||||
"""
|
"""
|
||||||
Performs actions before the model is loaded.
|
Performs actions before the model is loaded.
|
||||||
|
|
||||||
@@ -75,7 +74,7 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
def post_model_load(self, cfg, model):
|
||||||
"""
|
"""
|
||||||
Performs actions after the model is loaded.
|
Performs actions after the model is loaded.
|
||||||
|
|
||||||
@@ -87,7 +86,7 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
def pre_lora_load(self, cfg, model):
|
||||||
"""
|
"""
|
||||||
Performs actions before LoRA weights are loaded.
|
Performs actions before LoRA weights are loaded.
|
||||||
|
|
||||||
@@ -99,7 +98,7 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
def post_lora_load(self, cfg, model):
|
||||||
"""
|
"""
|
||||||
Performs actions after LoRA weights are loaded.
|
Performs actions after LoRA weights are loaded.
|
||||||
|
|
||||||
@@ -111,7 +110,7 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
def create_optimizer(self, cfg, trainer):
|
||||||
"""
|
"""
|
||||||
Creates and returns an optimizer for training.
|
Creates and returns an optimizer for training.
|
||||||
|
|
||||||
@@ -123,9 +122,7 @@ class BasePlugin:
|
|||||||
object: The created optimizer.
|
object: The created optimizer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_lr_scheduler(
|
def create_lr_scheduler(self, cfg, trainer, optimizer):
|
||||||
self, cfg, trainer, optimizer
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
"""
|
"""
|
||||||
Creates and returns a learning rate scheduler.
|
Creates and returns a learning rate scheduler.
|
||||||
|
|
||||||
@@ -138,7 +135,7 @@ class BasePlugin:
|
|||||||
object: The created learning rate scheduler.
|
object: The created learning rate scheduler.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
def add_callbacks_pre_trainer(self, cfg, model):
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer before training.
|
Adds callbacks to the trainer before training.
|
||||||
|
|
||||||
@@ -149,11 +146,8 @@ class BasePlugin:
|
|||||||
Returns:
|
Returns:
|
||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||||
"""
|
"""
|
||||||
return []
|
|
||||||
|
|
||||||
def add_callbacks_post_trainer(
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
self, cfg, trainer
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer after training.
|
Adds callbacks to the trainer after training.
|
||||||
|
|
||||||
@@ -164,9 +158,8 @@ class BasePlugin:
|
|||||||
Returns:
|
Returns:
|
||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||||
"""
|
"""
|
||||||
return []
|
|
||||||
|
|
||||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
def post_train(self, cfg, model):
|
||||||
"""
|
"""
|
||||||
Performs actions after training is complete.
|
Performs actions after training is complete.
|
||||||
|
|
||||||
@@ -178,7 +171,7 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
def post_train_unload(self, cfg):
|
||||||
"""
|
"""
|
||||||
Performs actions after training is complete and the model is unloaded.
|
Performs actions after training is complete and the model is unloaded.
|
||||||
|
|
||||||
@@ -234,7 +227,7 @@ class PluginManager:
|
|||||||
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
|
plugins: List[BasePlugin] = []
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
@@ -244,7 +237,7 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super(PluginManager, cls).__new__(cls)
|
cls._instance = super(PluginManager, cls).__new__(cls)
|
||||||
cls._instance.plugins = collections.OrderedDict()
|
cls._instance.plugins: List[BasePlugin] = []
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -272,7 +265,7 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
plugin = load_plugin(plugin_name)
|
plugin = load_plugin(plugin_name)
|
||||||
self.plugins[plugin_name] = plugin
|
self.plugins.append(plugin)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||||
|
|
||||||
@@ -284,7 +277,7 @@ class PluginManager:
|
|||||||
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
|
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
|
||||||
"""
|
"""
|
||||||
input_args = []
|
input_args = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
input_args_from_plugin = plugin.get_input_args()
|
input_args_from_plugin = plugin.get_input_args()
|
||||||
if input_args_from_plugin is not None:
|
if input_args_from_plugin is not None:
|
||||||
input_args.append(input_args_from_plugin)
|
input_args.append(input_args_from_plugin)
|
||||||
@@ -300,7 +293,7 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
plugin.pre_model_load(cfg)
|
plugin.pre_model_load(cfg)
|
||||||
|
|
||||||
def post_model_load(self, cfg, model):
|
def post_model_load(self, cfg, model):
|
||||||
@@ -314,7 +307,7 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
plugin.post_model_load(cfg, model)
|
plugin.post_model_load(cfg, model)
|
||||||
|
|
||||||
def pre_lora_load(self, cfg, model):
|
def pre_lora_load(self, cfg, model):
|
||||||
@@ -328,7 +321,7 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
plugin.pre_lora_load(cfg, model)
|
plugin.pre_lora_load(cfg, model)
|
||||||
|
|
||||||
def post_lora_load(self, cfg, model):
|
def post_lora_load(self, cfg, model):
|
||||||
@@ -342,7 +335,7 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
plugin.post_lora_load(cfg, model)
|
plugin.post_lora_load(cfg, model)
|
||||||
|
|
||||||
def create_optimizer(self, cfg, trainer):
|
def create_optimizer(self, cfg, trainer):
|
||||||
@@ -356,7 +349,7 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
object: The created optimizer, or None if none was found.
|
object: The created optimizer, or None if none was found.
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
optimizer = plugin.create_optimizer(cfg, trainer)
|
optimizer = plugin.create_optimizer(cfg, trainer)
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
return optimizer
|
return optimizer
|
||||||
@@ -374,7 +367,7 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
object: The created learning rate scheduler, or None if none was found.
|
object: The created learning rate scheduler, or None if none was found.
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
|
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
return scheduler
|
return scheduler
|
||||||
@@ -392,7 +385,7 @@ class PluginManager:
|
|||||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -408,7 +401,7 @@ class PluginManager:
|
|||||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -423,5 +416,5 @@ class PluginManager:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins:
|
||||||
plugin.post_train_unload(cfg)
|
plugin.post_train_unload(cfg)
|
||||||
|
|||||||
@@ -18,23 +18,20 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
|||||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||||
It is designed to be performant, correct, and light-weight.
|
It is designed to be performant, correct, and light-weight.
|
||||||
"""
|
"""
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
from ...utils.distributed import zero_only
|
|
||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
class LigerPlugin(BasePlugin):
|
||||||
"""
|
"""
|
||||||
@@ -45,31 +42,59 @@ class LigerPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.liger.LigerArgs"
|
return "axolotl.integrations.liger.LigerArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
if cfg.model_config_type == "llama":
|
||||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
from liger_kernel.transformers.model.llama import (
|
||||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
lce_forward as llama_lce_forward,
|
||||||
kwargs = {}
|
)
|
||||||
if "rope" in liger_fn_sig.parameters:
|
from transformers.models.llama import modeling_llama
|
||||||
kwargs["rope"] = cfg.liger_rope
|
|
||||||
if "cross_entropy" in liger_fn_sig.parameters:
|
if cfg.liger_rope:
|
||||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
if cfg.liger_rms_norm:
|
||||||
kwargs[
|
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
||||||
"fused_linear_cross_entropy"
|
if cfg.liger_swiglu:
|
||||||
] = cfg.liger_fused_linear_cross_entropy
|
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
||||||
if "rms_norm" in liger_fn_sig.parameters:
|
if cfg.liger_cross_entropy:
|
||||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if "layer_norm" in liger_fn_sig.parameters:
|
elif cfg.liger_fused_linear_cross_entropy:
|
||||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
||||||
if "geglu" in liger_fn_sig.parameters:
|
|
||||||
kwargs["geglu"] = cfg.liger_glu_activation
|
elif cfg.model_config_type == "mistral":
|
||||||
elif "swiglu" in liger_fn_sig.parameters:
|
from liger_kernel.transformers.model.mistral import (
|
||||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
lce_forward as mistral_lce_forward,
|
||||||
with zero_only():
|
)
|
||||||
LOG.info(
|
from transformers.models.mistral import modeling_mistral
|
||||||
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_swiglu:
|
||||||
|
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
||||||
|
|
||||||
|
elif cfg.model_config_type == "gemma":
|
||||||
|
from liger_kernel.transformers.model.gemma import (
|
||||||
|
lce_forward as gemma_lce_forward,
|
||||||
|
)
|
||||||
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_gemma.GemmaRMSNorm = partial(
|
||||||
|
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||||
)
|
)
|
||||||
apply_liger_fn(**kwargs)
|
if cfg.liger_swiglu:
|
||||||
|
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
||||||
|
|
||||||
elif cfg.model_config_type == "jamba":
|
elif cfg.model_config_type == "jamba":
|
||||||
from transformers.models.jamba import modeling_jamba
|
from transformers.models.jamba import modeling_jamba
|
||||||
|
|
||||||
@@ -79,12 +104,30 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
if cfg.liger_rms_norm:
|
if cfg.liger_rms_norm:
|
||||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_swiglu:
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||||
|
|
||||||
|
elif cfg.model_config_type == "qwen2":
|
||||||
|
from liger_kernel.transformers.model.qwen2 import (
|
||||||
|
lce_forward as qwen2_lce_forward,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen2 import modeling_qwen2
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_swiglu:
|
||||||
|
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
||||||
|
|
||||||
elif cfg.model_config_type == "deepseek_v2":
|
elif cfg.model_config_type == "deepseek_v2":
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
@@ -103,9 +146,44 @@ class LigerPlugin(BasePlugin):
|
|||||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
if cfg.liger_rms_norm:
|
if cfg.liger_rms_norm:
|
||||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_swiglu:
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
|
||||||
|
elif cfg.model_config_type == "gemma2":
|
||||||
|
from transformers.models.gemma2 import modeling_gemma2
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_gemma2.Gemma2RMSNorm = partial(
|
||||||
|
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||||
|
)
|
||||||
|
if cfg.liger_swiglu:
|
||||||
|
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
logging.warning(
|
||||||
|
"Fused linear cross entropy is not supported for Gemma 2."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif cfg.model_config_type == "phi3":
|
||||||
|
from liger_kernel.transformers.model.phi3 import (
|
||||||
|
lce_forward as phi3_lce_forward,
|
||||||
|
)
|
||||||
|
from transformers.models.phi3 import modeling_phi3
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_swiglu:
|
||||||
|
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
||||||
|
|||||||
@@ -15,12 +15,9 @@
|
|||||||
"""
|
"""
|
||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
import logging
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.liger.args")
|
|
||||||
|
|
||||||
|
|
||||||
class LigerArgs(BaseModel):
|
class LigerArgs(BaseModel):
|
||||||
@@ -30,24 +27,6 @@ class LigerArgs(BaseModel):
|
|||||||
|
|
||||||
liger_rope: Optional[bool] = None
|
liger_rope: Optional[bool] = None
|
||||||
liger_rms_norm: Optional[bool] = None
|
liger_rms_norm: Optional[bool] = None
|
||||||
liger_layer_norm: Optional[bool] = None
|
|
||||||
liger_swiglu: Optional[bool] = None
|
liger_swiglu: Optional[bool] = None
|
||||||
liger_glu_activation: Optional[bool] = None
|
|
||||||
liger_cross_entropy: Optional[bool] = None
|
liger_cross_entropy: Optional[bool] = None
|
||||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_deprecated_swiglu(cls, data):
|
|
||||||
if data.get("liger_swiglu") is not None:
|
|
||||||
if data.get("liger_glu_activation") is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
|
|
||||||
"Please use 'liger_glu_activation' instead."
|
|
||||||
)
|
|
||||||
data["liger_glu_activation"] = data.pop("liger_swiglu")
|
|
||||||
return data
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
|
|
||||||
@@ -43,19 +44,7 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def is_xformers_available() -> bool:
|
|
||||||
try:
|
|
||||||
import xformers # pylint: disable=unused-import # noqa: F401
|
|
||||||
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_xformers_swiglu_available() -> bool:
|
def is_xformers_swiglu_available() -> bool:
|
||||||
if not is_xformers_available():
|
|
||||||
return False
|
|
||||||
|
|
||||||
from xformers.ops.common import get_xformers_operator
|
from xformers.ops.common import get_xformers_operator
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def replace_llama_mlp_with_swiglu(model):
|
def replace_llama_mlp_with_swiglu(model):
|
||||||
if is_xformers_swiglu_available():
|
|
||||||
from axolotl.monkeypatch.xformers_ import FusedMLP
|
|
||||||
else:
|
|
||||||
raise RuntimeError("xformers SwiGLU not available for this environment")
|
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, LlamaMLP):
|
if isinstance(module, LlamaMLP):
|
||||||
mlp = FusedMLP(
|
mlp = FusedMLP(
|
||||||
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
|
|||||||
set_module_name(model, name, new_attn)
|
set_module_name(model, name, new_attn)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMLP(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
gate_proj: torch.nn.Linear,
|
||||||
|
up_proj: torch.nn.Linear,
|
||||||
|
down_proj: torch.nn.Linear,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.swiglu = SwiGLU(
|
||||||
|
in_features=config.hidden_size,
|
||||||
|
hidden_features=config.intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
_pack_weights=True,
|
||||||
|
)
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.swiglu.w12.weight.data = torch.cat(
|
||||||
|
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||||
|
)
|
||||||
|
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||||
|
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign the split weights back to the original layers
|
||||||
|
new_mlp = LlamaMLP(self.config)
|
||||||
|
new_mlp.gate_proj.weight.data = w1
|
||||||
|
new_mlp.up_proj.weight.data = w2
|
||||||
|
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_mlp)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||||
|
return self.swiglu(x)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
|
|||||||
@@ -16,6 +16,26 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
|
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
loss = fast_cross_entropy_loss(
|
||||||
|
logits = shift_logits,
|
||||||
|
labels = shift_labels,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
ORIGINAL_QKV_CODE = """
|
ORIGINAL_QKV_CODE = """
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
@@ -60,6 +80,12 @@ def get_forward_code() -> str:
|
|||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def check_cel_is_patchable() -> bool:
|
||||||
|
forward = get_forward_code()
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
return ORIGINAL_CEL_CODE in forward
|
||||||
|
|
||||||
|
|
||||||
def get_self_attn_code() -> str:
|
def get_self_attn_code() -> str:
|
||||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||||
return forward
|
return forward
|
||||||
@@ -72,31 +98,48 @@ def check_self_attn_is_patchable() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||||
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
|
|
||||||
|
|
||||||
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
|
|
||||||
logits,
|
|
||||||
labels,
|
|
||||||
vocab_size: int, # pylint: disable=unused-argument
|
|
||||||
num_items_in_batch: int = None,
|
|
||||||
ignore_index: int = -100, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
||||||
logits = logits.float()
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
|
|
||||||
loss = fast_cross_entropy_loss(
|
|
||||||
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
|
|
||||||
)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
if model_type == "llama":
|
if model_type == "llama":
|
||||||
from transformers.loss import loss_utils
|
forward = get_forward_code()
|
||||||
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
|
forward = forward.replace(
|
||||||
|
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||||
|
)
|
||||||
|
forward = forward.replace(
|
||||||
|
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||||
|
forward = forward.replace(
|
||||||
|
"def forward(",
|
||||||
|
"def fast_cross_entropy_loss_forward(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
|
if item in forward:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||||
|
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported model type")
|
raise ValueError("Unsupported model type")
|
||||||
|
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
"""
|
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
|
||||||
from xformers.ops import SwiGLU
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import set_module_name
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMLP(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
gate_proj: torch.nn.Linear,
|
|
||||||
up_proj: torch.nn.Linear,
|
|
||||||
down_proj: torch.nn.Linear,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.swiglu = SwiGLU(
|
|
||||||
in_features=config.hidden_size,
|
|
||||||
hidden_features=config.intermediate_size,
|
|
||||||
bias=False,
|
|
||||||
_pack_weights=True,
|
|
||||||
)
|
|
||||||
# overwrite initialized weights with pretrained weights
|
|
||||||
self.swiglu.w12.weight.data = torch.cat(
|
|
||||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
|
||||||
)
|
|
||||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
|
||||||
|
|
||||||
def _post_training(self, model, name):
|
|
||||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
|
||||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign the split weights back to the original layers
|
|
||||||
new_mlp = LlamaMLP(self.config)
|
|
||||||
new_mlp.gate_proj.weight.data = w1
|
|
||||||
new_mlp.up_proj.weight.data = w2
|
|
||||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
|
||||||
|
|
||||||
set_module_name(model, name, new_mlp)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
|
||||||
return self.swiglu(x)
|
|
||||||
@@ -260,10 +260,8 @@ def train(
|
|||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
try:
|
try:
|
||||||
trainer.create_model_card(
|
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
except AttributeError:
|
||||||
)
|
|
||||||
except (AttributeError, UnicodeDecodeError):
|
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# defensively push to the hub to ensure the model card is updated
|
# defensively push to the hub to ensure the model card is updated
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -57,7 +57,6 @@ class ChatTemplate(str, Enum):
|
|||||||
jinja = "jinja" # pylint: disable=invalid-name
|
jinja = "jinja" # pylint: disable=invalid-name
|
||||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
exaone = "exaone" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
@@ -584,7 +583,6 @@ class AxolotlInputConfig(
|
|||||||
resume_from_checkpoint: Optional[str] = None
|
resume_from_checkpoint: Optional[str] = None
|
||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||||
mean_resizing_embeddings: Optional[bool] = False
|
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
|
|||||||
@@ -2,11 +2,9 @@
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import requests
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
@@ -55,28 +53,6 @@ from axolotl.utils.trainer import (
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
|
||||||
def decorator(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except (
|
|
||||||
requests.exceptions.ReadTimeout,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
) as exc:
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
time.sleep(delay)
|
|
||||||
else:
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
|
||||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
|
|||||||
@@ -16,7 +16,3 @@ def setup_mlflow_env_vars(cfg: DictDefault):
|
|||||||
# Enable mlflow if experiment name is present
|
# Enable mlflow if experiment name is present
|
||||||
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
|
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
|
||||||
cfg.use_mlflow = True
|
cfg.use_mlflow = True
|
||||||
|
|
||||||
# Enable logging hf artifacts in mlflow if value is truthy
|
|
||||||
if cfg.hf_mlflow_log_artifacts is True:
|
|
||||||
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -133,8 +133,6 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
|
|
||||||
self.len_across_ranks = None
|
|
||||||
|
|
||||||
def set_epoch(self, epoch: int):
|
def set_epoch(self, epoch: int):
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
|
|
||||||
@@ -197,14 +195,15 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||||
return math.floor(0.998 * min(estimates))
|
return math.floor(0.998 * min(estimates))
|
||||||
|
|
||||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
min_len_batches = reduce_and_broadcast(
|
||||||
|
lambda: num,
|
||||||
|
calc_min_len,
|
||||||
|
)
|
||||||
return min_len_batches
|
return min_len_batches
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if not self.len_across_ranks:
|
len_batches = self.num_batches()
|
||||||
len_batches = self.num_batches()
|
return self.gather_len_batches(len_batches)
|
||||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
|
||||||
return self.len_across_ranks
|
|
||||||
|
|
||||||
def _len_est(self):
|
def _len_est(self):
|
||||||
efficiency = (
|
efficiency = (
|
||||||
|
|||||||
76
test.yml
76
test.yml
@@ -1,76 +0,0 @@
|
|||||||
base_model: JackFram/llama-68m
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.liger.LigerPlugin
|
|
||||||
liger_rope: true
|
|
||||||
liger_rms_norm: true
|
|
||||||
liger_glu_activation: true
|
|
||||||
liger_fused_linear_cross_entropy: true
|
|
||||||
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.5
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
sequence_len: 1024
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 100
|
|
||||||
evals_per_epoch: 2
|
|
||||||
eval_table_size:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
|
|
||||||
fsdp:
|
|
||||||
- full_shard
|
|
||||||
- auto_wrap
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_limit_all_gathers: true
|
|
||||||
fsdp_sync_module_states: true
|
|
||||||
fsdp_offload_params: true
|
|
||||||
fsdp_use_orig_params: false
|
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
|
||||||
fsdp_backward_prefetch: BACKWARD_PRE
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|finetune_right_pad_id|>
|
|
||||||
eos_token: <|eot_id|>
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Simple end-to-end test for Liger integration
|
Simple end-to-end test for Liger integration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -63,51 +64,6 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
|||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_llama_wo_flce2(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"plugins": [
|
|
||||||
"axolotl.integrations.liger.LigerPlugin",
|
|
||||||
],
|
|
||||||
"liger_rope": True,
|
|
||||||
"liger_rms_norm": True,
|
|
||||||
"liger_swiglu": True,
|
|
||||||
"liger_cross_entropy": True,
|
|
||||||
"liger_fused_linear_cross_entropy": False,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_llama_w_flce(self, temp_dir):
|
def test_llama_w_flce(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -1,155 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for multigpu eval
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUEval(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test case for MultiGPU Eval Sample Packing
|
|
||||||
"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_eval_sample_packing(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"load_in_8bit": False,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"strict": False,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"sample_packing": True,
|
|
||||||
"eval_sample_packing": True,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"loss_watchdog_threshold": 5.0,
|
|
||||||
"loss_watchdog_patience": 3,
|
|
||||||
"bf16": "auto",
|
|
||||||
"warmup_steps": 1,
|
|
||||||
"evals_per_epoch": 2,
|
|
||||||
"eval_max_new_tokens": 128,
|
|
||||||
"saves_per_epoch": 1,
|
|
||||||
"logging_steps": 1,
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_eval(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"load_in_8bit": False,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"strict": False,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"sample_packing": True,
|
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"loss_watchdog_threshold": 5.0,
|
|
||||||
"loss_watchdog_patience": 3,
|
|
||||||
"bf16": "auto",
|
|
||||||
"warmup_steps": 1,
|
|
||||||
"evals_per_epoch": 2,
|
|
||||||
"eval_max_new_tokens": 128,
|
|
||||||
"saves_per_epoch": 1,
|
|
||||||
"logging_steps": 1,
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
@@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import is_hopper, with_temp_dir
|
from ..utils import with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -59,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -116,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 50,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -144,146 +144,6 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"sample_packing": False,
|
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.05,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"rl": "dpo",
|
|
||||||
"chat_template": "llama3",
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
|
||||||
"type": "chat_template.default",
|
|
||||||
"field_messages": "conversation",
|
|
||||||
"field_chosen": "chosen",
|
|
||||||
"field_rejected": "rejected",
|
|
||||||
"message_field_role": "role",
|
|
||||||
"message_field_content": "content",
|
|
||||||
"roles": {
|
|
||||||
"system": ["system"],
|
|
||||||
"user": ["user"],
|
|
||||||
"assistant": ["assistant"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 15,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"warmup_steps": 0,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_qlora_ddp(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"sample_packing": False,
|
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.05,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"rl": "dpo",
|
|
||||||
"chat_template": "chatml",
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
|
||||||
"type": "chat_template.default",
|
|
||||||
"field_messages": "conversation",
|
|
||||||
"field_chosen": "chosen",
|
|
||||||
"field_rejected": "rejected",
|
|
||||||
"message_field_role": "role",
|
|
||||||
"message_field_content": "content",
|
|
||||||
"roles": {
|
|
||||||
"system": ["system"],
|
|
||||||
"user": ["user"],
|
|
||||||
"assistant": ["assistant"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 15,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"warmup_steps": 0,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fsdp(self, temp_dir):
|
def test_fsdp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -305,7 +165,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -371,7 +231,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -413,6 +273,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip("disabled due to upstream issue")
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -421,7 +282,6 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"mean_resizing_embeddings": True,
|
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
@@ -437,7 +297,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "</s>",
|
"pad_token": "<|end_of_text|>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -447,7 +307,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -513,7 +373,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -572,7 +432,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 100,
|
||||||
"warmup_steps": 20,
|
"warmup_steps": 20,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import require_torch_2_3_1, with_temp_dir
|
from ..utils import require_torch_2_1_1, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
Test case for Llama models using 4d attention with multipack
|
Test case for Llama models using 4d attention with multipack
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@require_torch_2_3_1
|
@require_torch_2_1_1
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_sdp_lora_packing(self, temp_dir):
|
def test_sdp_lora_packing(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
@@ -1,12 +1,22 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
from axolotl.monkeypatch.unsloth_ import (
|
||||||
|
check_cel_is_patchable,
|
||||||
|
check_self_attn_is_patchable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestUnslothIntegration(unittest.TestCase):
|
class TestUnslothIntegration(unittest.TestCase):
|
||||||
"""Unsloth monkeypatch integration tests."""
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
|
def test_is_cel_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_cel_is_patchable(),
|
||||||
|
"HF transformers loss code has changed and isn't patchable",
|
||||||
|
)
|
||||||
|
|
||||||
def test_is_self_attn_patchable(self):
|
def test_is_self_attn_patchable(self):
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
"""Module for testing ModelLoader."""
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="temp_dir")
|
|
||||||
def fixture_temp_dir():
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
yield temp_dir
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoadModelUtils:
|
|
||||||
"""
|
|
||||||
Testing module testing ModelLoader.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def setup_method(self):
|
|
||||||
# load config
|
|
||||||
self.cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"tokenizer_config": "JackFram/llama-68m",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": False,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
ModelLoader(
|
|
||||||
cfg=self.cfg,
|
|
||||||
tokenizer="",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
|
|
||||||
def test_convert_embedding_modules_dtype(
|
|
||||||
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
|
||||||
):
|
|
||||||
self.cfg.output_dir = temp_dir
|
|
||||||
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
|
||||||
self.model_loader.model, _ = load_model(
|
|
||||||
self.cfg,
|
|
||||||
self.model_loader.tokenizer,
|
|
||||||
inference=False,
|
|
||||||
reference_model=True,
|
|
||||||
)
|
|
||||||
self.model_loader.convert_embedding_modules_dtype(
|
|
||||||
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
|
||||||
)
|
|
||||||
for name, module in self.model_loader.model.named_modules():
|
|
||||||
if (
|
|
||||||
"norm" in name
|
|
||||||
or (before_kbit_train_or_finetune and name.endswith(".gate"))
|
|
||||||
or (
|
|
||||||
any(m in name for m in embedding_modules)
|
|
||||||
and hasattr(module, "weight")
|
|
||||||
)
|
|
||||||
):
|
|
||||||
for _, param in module.named_parameters():
|
|
||||||
assert param.dtype == dist_dtype
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for packed training
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from tbparse import SummaryReader
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import most_recent_subdir, with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPackedLlama(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test case for Packed training of llama models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_loss_packed(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "vicgalle/alpaca-gpt4",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
|
||||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
|
||||||
reader = SummaryReader(event_file)
|
|
||||||
df = reader.scalars # pylint: disable=invalid-name
|
|
||||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
|
||||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
|
||||||
@@ -9,8 +9,6 @@ from functools import wraps
|
|||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def with_temp_dir(test_func):
|
def with_temp_dir(test_func):
|
||||||
@wraps(test_func)
|
@wraps(test_func)
|
||||||
@@ -37,18 +35,13 @@ def most_recent_subdir(path):
|
|||||||
return subdir
|
return subdir
|
||||||
|
|
||||||
|
|
||||||
def require_torch_2_3_1(test_case):
|
def require_torch_2_1_1(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires torch >= 2.3.1
|
Decorator marking a test that requires torch >= 2.1.1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_min_2_3_1():
|
def is_min_2_1_1():
|
||||||
torch_version = version("torch")
|
torch_version = version("torch")
|
||||||
return torch_version >= "2.3.1"
|
return torch_version >= "2.1.1"
|
||||||
|
|
||||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_hopper():
|
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
|
||||||
return compute_capability == (9, 0)
|
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
"""
|
|
||||||
config validation tests for swiglu args
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_base_cfg")
|
|
||||||
def fixture_cfg():
|
|
||||||
return DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
|
||||||
"learning_rate": 0.000001,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseValidation:
|
|
||||||
"""
|
|
||||||
Base validation module to setup the log capture
|
|
||||||
"""
|
|
||||||
|
|
||||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def inject_fixtures(self, caplog):
|
|
||||||
self._caplog = caplog
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-public-methods
|
|
||||||
class TestValidation(BaseValidation):
|
|
||||||
"""
|
|
||||||
Test the validation module for liger
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_deprecated_swiglu(self, minimal_cfg):
|
|
||||||
test_cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"liger_swiglu": False,
|
|
||||||
}
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
updated_cfg = validate_config(test_cfg)
|
|
||||||
assert (
|
|
||||||
"The 'liger_swiglu' argument is deprecated"
|
|
||||||
in self._caplog.records[0].message
|
|
||||||
)
|
|
||||||
assert updated_cfg.liger_swiglu is None
|
|
||||||
assert updated_cfg.liger_glu_activations is False
|
|
||||||
|
|
||||||
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
|
|
||||||
test_cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"liger_swiglu": False,
|
|
||||||
"liger_glu_activations": True,
|
|
||||||
}
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
|
|
||||||
):
|
|
||||||
validate_config(test_cfg)
|
|
||||||
@@ -306,10 +306,6 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
"""Verify that processing data from the hub works with a specific revision"""
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
|
||||||
# make sure prepared_path is empty
|
|
||||||
shutil.rmtree(prepared_path, ignore_errors=True)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
@@ -371,44 +367,43 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
def test_load_local_hub_with_revision(self):
|
def test_load_local_hub_with_revision(self):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir2:
|
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
||||||
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
snapshot_download(
|
||||||
snapshot_download(
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_type="dataset",
|
||||||
repo_type="dataset",
|
local_dir=tmp_ds_path,
|
||||||
local_dir=tmp_ds_path,
|
revision="d05c1cb",
|
||||||
revision="d05c1cb",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
"ds_type": "parquet",
|
"ds_type": "parquet",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"data_files": [
|
"data_files": [
|
||||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
||||||
],
|
],
|
||||||
"revision": "d05c1cb",
|
"revision": "d05c1cb",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
self.tokenizer, cfg, prepared_path
|
self.tokenizer, cfg, prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from axolotl.utils import is_comet_available
|
|||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -1433,58 +1432,3 @@ class TestValidationComet(BaseValidation):
|
|||||||
|
|
||||||
for key in comet_env.keys():
|
for key in comet_env.keys():
|
||||||
os.environ.pop(key, None)
|
os.environ.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
class TestValidationMLflow(BaseValidation):
|
|
||||||
"""
|
|
||||||
Validation test for MLflow
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
{
|
|
||||||
"hf_mlflow_log_artifacts": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
|
|
||||||
assert new_cfg.hf_mlflow_log_artifacts is True
|
|
||||||
|
|
||||||
# Check it's not already present in env
|
|
||||||
assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ
|
|
||||||
|
|
||||||
setup_mlflow_env_vars(new_cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"
|
|
||||||
|
|
||||||
os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)
|
|
||||||
|
|
||||||
def test_mlflow_not_used_by_default(self, minimal_cfg):
|
|
||||||
cfg = DictDefault({}) | minimal_cfg
|
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
|
|
||||||
setup_mlflow_env_vars(new_cfg)
|
|
||||||
|
|
||||||
assert cfg.use_mlflow is not True
|
|
||||||
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
{
|
|
||||||
"mlflow_experiment_name": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
|
|
||||||
setup_mlflow_env_vars(new_cfg)
|
|
||||||
|
|
||||||
assert new_cfg.use_mlflow is True
|
|
||||||
|
|
||||||
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
|
|
||||||
|
|||||||
@@ -1,64 +1,18 @@
|
|||||||
"""Module for testing models utils file."""
|
"""Module for testing models utils file."""
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
|
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|
||||||
from transformers.utils.import_utils import is_torch_mps_available
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import ModelLoader, load_model
|
from axolotl.utils.models import load_model
|
||||||
|
|
||||||
|
|
||||||
class TestModelsUtils:
|
class ModelsUtilsTest(unittest.TestCase):
|
||||||
"""Testing module for models utils."""
|
"""Testing module for models utils."""
|
||||||
|
|
||||||
def setup_method(self) -> None:
|
|
||||||
# load config
|
|
||||||
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"model_type": "LlamaForCausalLM",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"load_in_4bit": False,
|
|
||||||
"adapter": "lora",
|
|
||||||
"flash_attention": False,
|
|
||||||
"sample_packing": True,
|
|
||||||
"device_map": "auto",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
|
|
||||||
spec=PreTrainedTokenizerBase
|
|
||||||
)
|
|
||||||
self.inference = False # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.reference_model = True # pylint: disable=attribute-defined-outside-init
|
|
||||||
|
|
||||||
# init ModelLoader
|
|
||||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
ModelLoader(
|
|
||||||
cfg=self.cfg,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
inference=self.inference,
|
|
||||||
reference_model=self.reference_model,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_set_device_map_config(self):
|
|
||||||
# check device_map
|
|
||||||
device_map = self.cfg.device_map
|
|
||||||
if is_torch_mps_available():
|
|
||||||
device_map = "mps"
|
|
||||||
self.model_loader.set_device_map_config()
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
assert "device_map" not in self.model_loader.model_kwargs
|
|
||||||
else:
|
|
||||||
assert device_map in self.model_loader.model_kwargs["device_map"]
|
|
||||||
|
|
||||||
# check torch_dtype
|
|
||||||
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
|
|
||||||
|
|
||||||
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -81,38 +35,3 @@ class TestModelsUtils:
|
|||||||
"shifted-sparse attention does not currently support sample packing"
|
"shifted-sparse attention does not currently support sample packing"
|
||||||
in str(exc.value)
|
in str(exc.value)
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
|
|
||||||
@pytest.mark.parametrize("load_in_8bit", [True, False])
|
|
||||||
@pytest.mark.parametrize("load_in_4bit", [True, False])
|
|
||||||
@pytest.mark.parametrize("gptq", [True, False])
|
|
||||||
def test_set_quantization_config(
|
|
||||||
self,
|
|
||||||
adapter,
|
|
||||||
load_in_8bit,
|
|
||||||
load_in_4bit,
|
|
||||||
gptq,
|
|
||||||
):
|
|
||||||
# init cfg as args
|
|
||||||
self.cfg.load_in_8bit = load_in_8bit
|
|
||||||
self.cfg.load_in_4bit = load_in_4bit
|
|
||||||
self.cfg.gptq = gptq
|
|
||||||
self.cfg.adapter = adapter
|
|
||||||
|
|
||||||
self.model_loader.set_quantization_config()
|
|
||||||
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
|
|
||||||
assert not (
|
|
||||||
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
|
||||||
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
|
||||||
)
|
|
||||||
elif load_in_8bit and self.cfg.adapter is not None:
|
|
||||||
assert self.model_loader.model_kwargs["load_in_8bit"]
|
|
||||||
elif load_in_4bit and self.cfg.adapter is not None:
|
|
||||||
assert self.model_loader.model_kwargs["load_in_4bit"]
|
|
||||||
|
|
||||||
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
|
|
||||||
self.cfg.adapter == "lora" and load_in_8bit
|
|
||||||
):
|
|
||||||
assert self.model_loader.model_kwargs.get(
|
|
||||||
"quantization_config", BitsAndBytesConfig
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user