Compare commits


7 Commits

Author | SHA1 | Message | Date
Wing Lian | 3afc91fba9 | run 2.5.1 test without waiting for 1st e2e | 2024-12-07 17:25:16 -05:00
Wing Lian | 0689419d25 | use pr base tag | 2024-12-07 17:25:16 -05:00
Wing Lian | e64c32c0bd | push test build | 2024-12-07 17:25:16 -05:00
Wing Lian | ec819dde3b | attempt to build the test images | 2024-12-07 17:25:16 -05:00
Wing Lian | fdf4bb5087 | fix default base image | 2024-12-07 17:25:16 -05:00
Wing Lian | f67d16268c | try with default tag | 2024-12-07 17:25:16 -05:00
Wing Lian | 684b543aa1 | experiment with nvcr pytorch image for torch 2.5.1 | 2024-12-07 17:25:16 -05:00
42 changed files with 321 additions and 1251 deletions

View File

@@ -22,36 +22,38 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "121" # - cuda: "121"
cuda_version: 12.1.1 # cuda_version: 12.1.1
cudnn_version: 8 # cudnn_version: 8
python_version: "3.10" # python_version: "3.10"
pytorch: 2.3.1 # pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" # torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121" # from_base_img: ""
cuda_version: 12.1.1 # from_base_tag: ""
cudnn_version: 8 # - cuda: "121"
python_version: "3.11" # cuda_version: 12.1.1
pytorch: 2.3.1 # cudnn_version: 8
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" # python_version: "3.11"
- cuda: "124" # pytorch: 2.3.1
cuda_version: 12.4.1 # torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
cudnn_version: "" # from_base_img: ""
python_version: "3.10" # from_base_tag: ""
pytorch: 2.4.1 # - cuda: "124"
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" # cuda_version: 12.4.1
- cuda: "124" # cudnn_version: ""
cuda_version: 12.4.1 # python_version: "3.11"
cudnn_version: "" # pytorch: 2.4.1
python_version: "3.11" # torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
pytorch: 2.4.1 # from_base_img: ""
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" # from_base_tag: ""
- cuda: "124" - cuda: "124"
cuda_version: 12.4.1 cuda_version: 12.4.1
cudnn_version: "" cudnn_version: ""
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
from_base_img: nvcr.io/nvidia/pytorch
from_base_tag: 24.10-py3
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -61,7 +63,7 @@ jobs:
with: with:
images: | images: |
winglian/axolotl-base winglian/axolotl-base
axolotlai/axolotl-base # axolotlai/axolotl-base
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
with: with:
@@ -74,7 +76,8 @@ jobs:
with: with:
context: . context: .
file: ./docker/Dockerfile-base file: ./docker/Dockerfile-base
push: ${{ github.event_name != 'pull_request' }} push: true
# push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-args: | build-args: |
@@ -84,3 +87,5 @@ jobs:
PYTHON_VERSION=${{ matrix.python_version }} PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }} TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
BASE_IMAGE=${{ matrix.from_base_img || '' }}
BASE_TAG=${{ matrix.from_base_tag || '' }}

View File

@@ -13,13 +13,10 @@ jobs:
permissions: permissions:
contents: write contents: write
steps: steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Create release - name: Create release
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: gh release create "$GITHUB_REF_NAME" --generate-notes run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
pypi-publish: pypi-publish:
name: Upload release to PyPI name: Upload release to PyPI
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -41,7 +38,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install wheel packaging pip3 install wheel packaging
pip3 install --no-build-isolation -e . pip3 install -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name - name: Extract tag name

View File

@@ -60,15 +60,11 @@ jobs:
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging pip3 install --upgrade packaging
pip3 install --no-build-isolation -U -e . pip3 install -U -e .
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed - name: Ensure axolotl CLI was installed
run: | run: |
axolotl --help axolotl --help

View File

@@ -78,23 +78,19 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 show torch pip3 show torch
pip3 install --no-build-isolation -U -e . pip3 install -U -e .
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed - name: Ensure axolotl CLI was installed
run: | run: |
axolotl --help axolotl --help
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/ pytest tests/patched/
- name: cleanup pip cache - name: cleanup pip cache
run: | run: |
@@ -124,7 +120,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools setuptools_scm build wheel pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |
@@ -133,86 +129,83 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 show torch pip3 show torch
python -m build --no-isolation --sdist python3 setup.py sdist
pip3 install --no-build-isolation dist/axolotl*.tar.gz pip3 install dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed - name: Ensure axolotl CLI was installed
run: | run: |
axolotl --help axolotl --help
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/ pytest tests/patched/
- name: cleanup pip cache - name: cleanup pip cache
run: | run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st: # docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} # if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners... # # this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal] # runs-on: [self-hosted, modal]
timeout-minutes: 90 # timeout-minutes: 90
needs: [pre-commit, pytest, pytest-sdist] # needs: [pre-commit, pytest, pytest-sdist]
#
strategy: # strategy:
fail-fast: false # fail-fast: false
matrix: # matrix:
include: # include:
- cuda: 124 # - cuda: 124
cuda_version: 12.4.1 # cuda_version: 12.4.1
python_version: "3.11" # python_version: "3.11"
pytorch: 2.4.1 # pytorch: 2.4.1
num_gpus: 1 # num_gpus: 1
axolotl_extras: # axolotl_extras:
steps: # steps:
- name: Checkout # - name: Checkout
uses: actions/checkout@v4 # uses: actions/checkout@v4
- name: Install Python # - name: Install Python
uses: actions/setup-python@v5 # uses: actions/setup-python@v5
with: # with:
python-version: "3.10" # python-version: "3.10"
- name: Install Modal # - name: Install Modal
run: | # run: |
python -m pip install --upgrade pip # python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2 # pip install modal==0.63.64 jinja2
- name: Update env vars # - name: Update env vars
run: | # run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV # echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV # echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV # echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV # echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV # echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV # echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal # - name: Run tests job on Modal
run: | # run: |
modal run cicd.tests # modal run cicd.tests
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 90 timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st] # needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest]
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 121 # - cuda: 121
cuda_version: 12.1.1 # cuda_version: 12.1.1
python_version: "3.10" # python_version: "3.10"
pytorch: 2.3.1 # pytorch: 2.3.1
num_gpus: 1 # num_gpus: 1
axolotl_extras: mamba-ssm # axolotl_extras: mamba-ssm
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
@@ -232,7 +225,7 @@ jobs:
pip install modal==0.63.64 jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=pr-2139-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV

View File

@@ -1,5 +1,4 @@
include requirements.txt include requirements.txt
include README.md include README.md
include LICENSE include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
recursive-include axolotl *.py recursive-include axolotl *.py

README.md
View File

@@ -10,13 +10,9 @@
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License"> <img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a> <a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars"> <img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
<br/> </p>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a> <p align="center">
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p> </p>
@@ -46,8 +42,7 @@ Features:
- [Axolotl](#axolotl) - [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents) - [Table of Contents](#table-of-contents)
- [Quickstart ⚡](#quickstart-) - [Quickstart ⚡](#quickstart-)
- [Edge Builds](#edge-builds-) - [Usage](#usage)
- [Axolotl CLI Usage](#axolotl-cli-usage)
- [Badge ❤🏷️](#badge-) - [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-) - [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-) - [Sponsors 🤝❤](#sponsors-)
@@ -112,49 +107,58 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1. **Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash ```bash
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] git clone https://github.com/axolotl-ai-cloud/axolotl
# download examples and optionally deepspeed configs to the local path
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```
### Edge Builds 🏎️
If you're looking for the latest features and updates between releases, you'll need to install
from source.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install packaging ninja pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
``` ```
### Axolotl CLI Usage ### Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). ```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
### Axolotl CLI
If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:
```bash ```bash
# preprocess datasets - optional but recommended # preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
# finetune lora # finetune lora
axolotl train examples/llama-3/lora-1b.yml axolotl train examples/openllama-3b/lora.yml
# inference # inference
axolotl inference examples/llama-3/lora-1b.yml \ axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --lora-model-dir="./outputs/lora-out"
# gradio # gradio
axolotl inference examples/llama-3/lora-1b.yml \ axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --gradio --lora-model-dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL # remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml # Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
``` ```
We've also added a new command for fetching `examples` and `deepspeed_configs` to your We've also added a new command for fetching `examples` and `deepspeed_configs` to your
@@ -171,36 +175,6 @@ axolotl fetch deepspeed_configs
axolotl fetch examples --dest path/to/folder axolotl fetch examples --dest path/to/folder
``` ```
### Legacy Usage
<details>
<summary>Click to Expand</summary>
While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
</details>
## Badge ❤🏷️ ## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card. Building something cool with Axolotl? Consider adding a badge to your model card.
@@ -320,7 +294,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
3. Install Axolotl along with python dependencies 3. Install Axolotl along with python dependencies
```bash ```bash
pip3 install packaging pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
``` ```
4. (Optional) Login to Huggingface to use gated models/datasets. 4. (Optional) Login to Huggingface to use gated models/datasets.
```bash ```bash
@@ -399,7 +373,7 @@ Please use WSL or Docker!
Use the below instead of the install method in QuickStart. Use the below instead of the install method in QuickStart.
``` ```
pip3 install --no-build-isolation -e '.' pip3 install -e '.'
``` ```
More info: [mac.md](/docs/mac.qmd) More info: [mac.md](/docs/mac.qmd)

View File

@@ -1,4 +1,4 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }} FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}" ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi fi
RUN python scripts/unsloth_install.py | sh RUN python scripts/unsloth_install.py | sh

View File

@@ -1,10 +1,7 @@
#!/bin/bash #!/bin/bash
set -e set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,5 +1,6 @@
ARG BASE_IMAGE=axolotlai/axolotl-base
ARG BASE_TAG=main-base ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base:$BASE_TAG FROM $BASE_IMAGE:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi fi
RUN python scripts/unsloth_install.py | sh RUN python scripts/unsloth_install.py | sh

View File

@@ -3,7 +3,10 @@ ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder ARG BASE_IMAGE=nvidia/cuda
ARG DEFAULT_TAG=${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_TAG=""
FROM ${BASE_IMAGE:-nvidia/cuda}:${BASE_TAG:-${DEFAULT_TAG}} AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
@@ -16,7 +19,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
cd flash-attention cd flash-attention
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
pip install --no-build-isolation . pip install .
``` ```
### 6. Install Axolotl ### 6. Install Axolotl
@@ -63,7 +63,7 @@ Clone and install Axolotl:
git clone https://github.com/axolotl-ai-cloud/axolotl git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl cd axolotl
pip install packaging ninja pip install packaging ninja
pip install --no-build-isolation -e . pip install -e .
``` ```
### 7. Apply xformers Workaround ### 7. Apply xformers Workaround

View File

@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
```bash ```bash
pip3 install packaging pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
``` ```
#### Remote Hosts #### Remote Hosts
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
```bash ```bash
pip3 install packaging pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
``` ```
### Attach To Container ### Attach To Container

View File

@@ -52,26 +52,6 @@ datasets:
type: chat_template.argilla type: chat_template.argilla
``` ```
#### KTO
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
remove_unused_columns: false
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
```
#### Using local dataset files #### Using local dataset files
```yaml ```yaml
datasets: datasets:

View File

@@ -24,7 +24,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install --no-build-isolation axolotl[deepspeed]" "!pip install axolotl[deepspeed]"
] ]
}, },
{ {

View File

@@ -1,58 +0,0 @@
base_model: NousResearch/Meta-Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,74 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,73 +0,0 @@
base_model: NousResearch/Meta-Llama-3.1-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,75 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/qlora-out
remove_unused_columns: false
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: false # not supported with kto
eval_sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,4 +1,4 @@
base_model: NousResearch/Llama-3.2-1B base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
@@ -22,6 +22,7 @@ pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
lora_target_modules: lora_target_modules:
- gate_proj - gate_proj

View File

@@ -17,10 +17,3 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git" Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools_scm] [tool.setuptools_scm]
[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"

View File

@@ -1,30 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.4.2
# END section
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers>=4.46.3 transformers==4.47.0
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.45.0
accelerate==1.2.0 accelerate==1.2.0
datasets==3.1.0 datasets==3.1.0
deepspeed==0.16.1 deepspeed==0.15.4
pydantic==2.6.3 pydantic==2.6.3
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
flash-attn==2.7.0.post2
sentencepiece sentencepiece
wandb wandb
einops einops
xformers>=0.0.23.post1
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
colorama colorama
@@ -39,6 +31,11 @@ art
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq==0.2.7.post2
triton>=2.3.0
liger-kernel==0.4.2
mamba-ssm==1.2.0.post1
# remote filesystems # remote filesystems
s3fs>=2024.5.0 s3fs>=2024.5.0

View File

@@ -13,5 +13,5 @@ cd /workspace
rm -rf /workspace/axolotl rm -rf /workspace/axolotl
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip install --no-build-isolation --no-deps -e . pip install --no-deps -e .
``` ```

View File

@@ -1,10 +1,7 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
import ast
import os
import platform import platform
import re import re
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from setuptools import find_packages, setup from setuptools import find_packages, setup
@@ -93,24 +90,9 @@ def parse_requirements():
return _install_requires, _dependency_links return _install_requires, _dependency_links
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
"r",
encoding="utf-8",
) as fin:
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
return version_
install_requires, dependency_links = parse_requirements() install_requires, dependency_links = parse_requirements()
setup( setup(
version=get_package_version(),
package_dir={"": "src"}, package_dir={"": "src"},
packages=find_packages("src"), packages=find_packages("src"),
install_requires=install_requires, install_requires=install_requires,
@@ -125,7 +107,7 @@ setup(
"flash-attn==2.7.0.post2", "flash-attn==2.7.0.post2",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.16.1", "deepspeed==0.15.4",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -1,3 +1,8 @@
"""Axolotl - Train and fine-tune large language models""" """Axolotl - Train and fine-tune large language models"""
__version__ = "0.6.0" try:
from importlib.metadata import version
__version__ = version("axolotl")
except ImportError:
__version__ = "unknown"

View File

@@ -5,7 +5,6 @@ from typing import Optional
import click import click
import axolotl
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -17,7 +16,6 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group() @click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli(): def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -974,13 +973,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs, start_time)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1172,13 +1165,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
@@ -1196,13 +1185,9 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
@@ -1247,13 +1232,9 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
@@ -1271,13 +1252,9 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): logs, start_time
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call )
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
@@ -1289,12 +1266,9 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method # TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call logs, start_time
logs, start_time )
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
@@ -1319,10 +1293,6 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"): if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"]) model.add_model_tags(["axolotl"])
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
# self.model =
@property @property
def model_ref(self): def model_ref(self):
return self._model_ref return self._model_ref
@@ -1372,6 +1342,8 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_mlflow and is_mlflow_available(): if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import ( from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoMlflowCallback,
) )
@@ -1379,6 +1351,7 @@ class TrainerBuilderBase(abc.ABC):
callbacks.extend( callbacks.extend(
[ [
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path), SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
] ]
) )
if self.cfg.use_comet and is_comet_available(): if self.cfg.use_comet and is_comet_available():

View File

@@ -1,80 +0,0 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -5,7 +5,8 @@ see https://github.com/huggingface/transformers/pull/35128
import inspect import inspect
import logging import logging
from transformers import LlamaForCausalLM, Trainer from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code from axolotl.monkeypatch.unsloth_ import detab_code
@@ -204,87 +205,3 @@ def patch_forward_for_ga():
LlamaForCausalLM.forward = ( # pylint: disable=protected-access LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821 _fixed_forward # pylint: disable=undefined-variable # noqa: F821
) )
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -28,8 +28,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:return: :return:
""" """
max_length = self.prompter.max_length
self.messages = "chosen_messages" self.messages = "chosen_messages"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
@@ -41,16 +39,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt) chosen_tokenized = super().tokenize_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
:max_length
]
self.messages = "rejected_messages" self.messages = "rejected_messages"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
@@ -64,18 +52,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
) )
rejected_tokenized = super().tokenize_prompt(prompt) rejected_tokenized = super().tokenize_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
:max_length
]
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
:max_length
]
return { return {
"input_ids_chosen": chosen_tokenized["input_ids"], "input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"], "attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -104,9 +80,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"roles": ds_cfg.get("roles"), "roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False), "drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit. # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": ( "max_length": cfg.sequence_len + 1
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len if not cfg.reward_model
), else cfg.sequence_len,
} }
strategy_params = { strategy_params = {

View File

@@ -42,7 +42,6 @@ class ChatTemplatePrompter(Prompter):
"gpt": "assistant", "gpt": "assistant",
"system": "system", "system": "system",
} }
self.message_field_role = message_field_role self.message_field_role = message_field_role
self.message_field_content = message_field_content self.message_field_content = message_field_content
self.message_field_training = message_field_training self.message_field_training = message_field_training
@@ -54,9 +53,21 @@ class ChatTemplatePrompter(Prompter):
self.drop_system_message = drop_system_message self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None): def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor: if self.processor:
text = self.processor.apply_chat_template( text = self.processor.apply_chat_template(
conversation, turns,
chat_template=self.chat_template, chat_template=self.chat_template,
tokenize=False, tokenize=False,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
@@ -65,6 +76,8 @@ class ChatTemplatePrompter(Prompter):
text=text, text=text,
images=images, images=images,
return_tensors="pt", return_tensors="pt",
truncation=True,
max_length=self.max_length,
) )
# workaround since processor works in batches instead of single examples # workaround since processor works in batches instead of single examples
for k, val in batch.items(): for k, val in batch.items():
@@ -75,7 +88,9 @@ class ChatTemplatePrompter(Prompter):
return batch return batch
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
conversation, turns,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template, chat_template=self.chat_template,
) )
@@ -200,14 +215,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos=None, train_on_eos=None,
): ):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.roles_to_train = []
if roles_to_train:
# map roles if exist in prompter.roles else use the role as is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]
self.train_on_eos = train_on_eos self.train_on_eos = train_on_eos
self.images = "images" self.images = "images"
@@ -254,28 +262,30 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt return tokenized_prompt
turns = self.get_conversation_thread(prompt) turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns) input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids) labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1 last_eos_idx = -1
for index, turn in enumerate(turns): for index, turn in enumerate(turns):
role = turn.get("role") role = turn.get(self.prompter.message_field_role)
content = turn.get("content") content = turn.get(self.prompter.message_field_content)
train_turn = turn.get("training") train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get("training_detail") train_detail = turn.get(self.prompter.message_field_training_detail)
LOG.debug( LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}" f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
) )
should_train = None should_train = (
if train_turn is not None: train_turn
should_train = train_turn if train_turn is not None
elif train_detail is not None: else (
should_train = bool(train_detail) bool(train_detail is not None)
else: if train_detail is not None
should_train = self.train_on_inputs or role in self.roles_to_train else self.train_on_inputs or role in self.roles_to_train
)
)
LOG.debug(f"Should train: {should_train}") LOG.debug(f"Should train: {should_train}")
@@ -283,9 +293,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
conversation_ids=input_ids, turn=index, turn_content=turn conversation_ids=input_ids, turn=index, turn_content=turn
) )
if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1: if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -306,9 +313,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
labels[turn_start_idx:turn_end_idx] = input_ids[ labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx turn_start_idx:turn_end_idx
] ]
LOG.debug( LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
LOG.debug(f"Labels after processing turn {index}: {labels}") LOG.debug(f"Labels after processing turn {index}: {labels}")
@@ -346,73 +351,52 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i return i
return -1 return -1
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict): def find_turn(self, conversation_ids, turn, turn_content):
""" """
Locate the starting and ending indices of the specified turn in a conversation. Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (dict): Dict containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
""" """
content = turn_content.get("content") content = turn_content.get(self.prompter.message_field_content, "")
content_ids = self.tokenizer.encode(content, add_special_tokens=False) content_ids = self.tokenizer.encode(content, add_special_tokens=False)
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}") eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
if not content_ids: # Locate the starting index after the specified number of EOS tokens
LOG.warning(f"Empty content for turn {turn}") for i, token_id in enumerate(conversation_ids):
return -1, -1 if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
# For first turn, start from beginning # Find the start index of the content within the conversation
if turn == 0: start_idx = -1
start_search_idx = 0 for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
else: else:
# For subsequent turns, find the previous EOS token end_idx = -1
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
for i, token_id in enumerate(conversation_ids): return start_idx, end_idx
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break
# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1
if last_index < start_search_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
)
return -1, -1
# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)
return -1, -1
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
turns = [ return prompt[self.messages]
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
]
if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
return turns
def get_images(self, prompt): def get_images(self, prompt):
return prompt.get(self.images, None) return prompt.get(self.images, None)
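
Both versions of find_turn above implement the same search: skip past the EOS token that closes the previous turn, then slide a window over the remaining tokens looking for the turn's content IDs. A condensed sketch under those assumptions:

def find_span(conversation_ids, content_ids, eos_token_id, turn):
    # Count EOS tokens to find where the previous turn ended.
    start_search_idx = 0
    eos_count = 0
    for i, token_id in enumerate(conversation_ids):
        if token_id == eos_token_id:
            eos_count += 1
            if eos_count == turn:
                start_search_idx = i + 1
                break
    # Scan forward for an exact match of the content token IDs.
    for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
        if conversation_ids[i : i + len(content_ids)] == content_ids:
            return i, i + len(content_ids)
    return -1, -1
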

View File

@@ -259,7 +259,14 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id: if not cfg.hub_model_id:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
try: try:
# Check to make sure the base model is from HuggingFace not a local directory
hf_api = HfApi()
hf_api.model_info(cfg.base_model)
model_card_kwarg = { model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./") "model_name": cfg.output_dir.lstrip("./")
.encode("utf-8") .encode("utf-8")
@@ -267,22 +274,16 @@ def train(
} }
if cfg.datasets is not None: if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model: if cfg.rl is not None or cfg.reward_model:
dataset_tags = [ model_card_kwarg["dataset_name"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
] ]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else: else:
dataset_tags = [ model_card_kwarg["dataset_tags"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
] ]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg) trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError): except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
pass pass
elif cfg.hub_model_id: elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated # defensively push to the hub to ensure the model card is updated
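
A minimal illustration of the Hub check introduced above; the helper name is made up for this sketch, and the actual change catches RepositoryNotFoundError alongside the existing exceptions:

from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

def model_exists_on_hub(repo_id: str) -> bool:
    try:
        # Raises RepositoryNotFoundError when the repo cannot be found on the Hub,
        # e.g. when base_model points at something that is not a published model.
        HfApi().model_info(repo_id)
        return True
    except RepositoryNotFoundError:
        return False
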

View File

@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
default=None, json_schema_extra={"description": "transformers processor class"} default=None, json_schema_extra={"description": "transformers processor class"}
) )
trust_remote_code: Optional[bool] = None trust_remote_code: Optional[bool] = None
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
model_kwargs: Optional[Dict[str, Any]] = None model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code") @field_validator("trust_remote_code")
@@ -1475,27 +1475,6 @@ class AxolotlInputConfig(
return data return data
@model_validator(mode="before")
@classmethod
def check_kto_config(cls, data):
if data.get("rl") == "kto":
if data.get("sample_packing") or data.get("eval_sample_packing"):
raise ValueError("sample_packing is not supported with kto")
if data.get("remove_unused_columns") is not False:
raise ValueError("Set `remove_unused_columns: False` when using kto")
if data.get("gradient_checkpointing") and not (
data.get("gradient_checkpointing_kwargs")
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
):
raise ValueError(
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options""" """wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -380,19 +380,6 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
from axolotl.monkeypatch.trainer_grad_accum import (
patch_training_loop_for_deepspeed_0_16_x,
)
patch_training_loop_for_deepspeed_0_16_x()
if self.cfg.gradient_checkpointing == "unsloth": if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
@@ -419,14 +406,10 @@ class ModelLoader:
and self.cfg.flash_attention and self.cfg.flash_attention
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
if "auto_map" in self.model_config: has_remote_code = (
try: "auto_map" in self.model_config
auto_map_config = self.model_config["auto_map"] and "AutoModelForCausalLM" in self.model_config["auto_map"]
except TypeError: )
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False: if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code
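
For context on the simplified check above: models that ship remote code list their custom classes under auto_map in config.json, so membership of AutoModelForCausalLM in that mapping is the signal for custom modeling code. A hypothetical example of such a config entry:

# Hypothetical excerpt of a config.json for a model with remote code.
model_config = {
    "architectures": ["MyModelForCausalLM"],
    "auto_map": {
        "AutoConfig": "configuration_mymodel.MyModelConfig",
        "AutoModelForCausalLM": "modeling_mymodel.MyModelForCausalLM",
    },
}
has_remote_code = (
    "auto_map" in model_config
    and "AutoModelForCausalLM" in model_config["auto_map"]
)
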
@@ -1187,15 +1170,9 @@ class ModelLoader:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.post_loading_set_env()
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return self.model, lora_config return self.model, lora_config
def post_loading_set_env(self):
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
def load_model( def load_model(
cfg: DictDefault, cfg: DictDefault,

View File

@@ -1,104 +0,0 @@
"""
dynamic requirements for axolotl
"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools.command.build_py import build_py as _build_py
# pylint: disable=duplicate-code
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
class BuildPyCommand(_build_py):
"""
custom build_py command to parse dynamic requirements
"""
def finalize_options(self):
super().finalize_options()
install_requires, _ = parse_requirements()
self.distribution.install_requires = install_requires
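
A quick worked example of the version parsing used above; the input string is illustrative, since local torch builds often carry a +cuXXX suffix:

import re

torch_version = "2.5.1+cu124"
match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
major, minor, patch = match.groups()              # ('2', '5', '1')
print(int(major), int(minor), int(patch or 0))    # 2 5 1

The (major, minor) tuple is then compared against release cutoffs to pick a compatible xformers pin.
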

View File

@@ -1,10 +0,0 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli
def test_print_version(cli_runner):
"""Test that version is printed when --version is used."""
result = cli_runner.invoke(cli, ["--version"])
assert result.exit_code == 0
assert "axolotl, version " in result.output

View File

@@ -119,52 +119,25 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
from transformers import Trainer from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
original_fa2_forward = LlamaFlashAttention2.forward original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests # monkey patches can happen inside the tests
yield yield
# Reset LlamaFlashAttention2 forward # Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches # Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [ modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama",), ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
),
("transformers.trainer",), ("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",), ("transformers.loss.loss_utils",),
] ]
for module_name_tuple in modules_to_reset: for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0] module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
spec = importlib.util.spec_from_file_location( sys.modules[module_name] = module
module_name, sys.modules[module_name].__file__ importlib.reload(sys.modules[module_name])
)
sys.modules[module_name] = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sys.modules[module_name])
sys.modules[module_name] = importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1: if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1] module_globals = module_name_tuple[1]
for module_global in module_globals: for module_global in module_globals:
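
The fixture above follows the usual snapshot/patch/restore pattern; a stripped-down sketch of the same idea, reusing the names that appear in the diff:

import importlib
import pytest

@pytest.fixture(autouse=True)
def restore_llama_forward():
    from transformers.models.llama import modeling_llama

    original_forward = modeling_llama.LlamaFlashAttention2.forward  # snapshot before the test
    yield                                                           # the test may monkeypatch here
    modeling_llama.LlamaFlashAttention2.forward = original_forward  # restore afterwards
    # Reloading the module clears any other module-level patches left behind.
    importlib.reload(modeling_llama)
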

View File

@@ -71,11 +71,7 @@ class TestCutCrossEntropyIntegration:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attention_type", "attention_type",
[ ["flash_attention", "sdp_attention", "xformers_attention"],
"flash_attention",
"sdp_attention",
# "xformers_attention",
],
) )
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type): def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cfg = DictDefault( cfg = DictDefault(

View File

@@ -9,7 +9,6 @@ from pathlib import Path
import pytest import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
@@ -54,7 +53,7 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 4,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -62,7 +61,6 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True,
} }
) )
@@ -85,13 +83,9 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 2], [1, 4],
) )
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps): def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -118,15 +112,14 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 1, "micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True,
} }
) )
@@ -149,10 +142,6 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_lora_ddp(self, temp_dir): def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
@@ -191,7 +180,7 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 4,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -200,7 +189,6 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True,
} }
) )
@@ -223,10 +211,6 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_qlora_ddp(self, temp_dir): def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
@@ -265,8 +249,8 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 2, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 4,
"output_dir": temp_dir, "output_dir": temp_dir,
"warmup_steps": 0, "warmup_steps": 0,
@@ -274,7 +258,6 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"use_tensorboard": True,
} }
) )
@@ -297,13 +280,9 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 2], [1, 4],
) )
def test_fsdp(self, temp_dir, gradient_accumulation_steps): def test_fsdp(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -322,8 +301,8 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 10,
"micro_batch_size": 2, "micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
@@ -344,7 +323,6 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": "FULL_STATE_DICT", "fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
}, },
"use_tensorboard": True,
} }
) )
@@ -367,10 +345,6 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fsdp_state_dict_type", "fsdp_state_dict_type",
["FULL_STATE_DICT", "SHARDED_STATE_DICT"], ["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
@@ -394,7 +368,7 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 4,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -416,7 +390,6 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": fsdp_state_dict_type, "fsdp_state_dict_type": fsdp_state_dict_type,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
}, },
"use_tensorboard": True,
} }
) )
@@ -439,10 +412,6 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir): def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
@@ -475,7 +444,7 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 4, "micro_batch_size": 4,
"gradient_accumulation_steps": 4, "gradient_accumulation_steps": 4,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -497,7 +466,6 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": "SHARDED_STATE_DICT", "fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
}, },
"use_tensorboard": True,
} }
) )
@@ -520,41 +488,12 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 2], [1, 4],
) )
@pytest.mark.parametrize( def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
"deepspeed",
[
"deepspeed_configs/zero3_bf16.json",
"deepspeed_configs/zero3_bf16_cpuoffload_all.json",
# "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero3_packed(
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -572,17 +511,15 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 1, "micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed), "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
**adapter,
} }
) )
@@ -605,35 +542,19 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard( def test_ds_zero3_qlora_packed(self, temp_dir):
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
if qlora: cfg = DictDefault(
adapter = { {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True, "sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 2048, "sequence_len": 2048,
"val_set_size": 0.05, "val_set_size": 0.05,
@@ -647,17 +568,15 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 15,
"micro_batch_size": 1, "micro_batch_size": 4,
"gradient_accumulation_steps": gradient_accumulation_steps, "gradient_accumulation_steps": 4,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.0001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"flash_attention": True, "flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"), "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
**adapter,
} }
) )
@@ -679,82 +598,3 @@ class TestMultiGPULlama:
str(Path(temp_dir) / "config.yaml"), str(Path(temp_dir) / "config.yaml"),
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
**adapter,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)

View File

@@ -4,6 +4,7 @@ E2E tests for lora llama
import logging import logging
import os import os
from importlib import reload
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -21,6 +22,14 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@pytest.fixture(autouse=True)
def reload_transformers():
import transformers.models.llama.modeling_llama
yield
reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama: class TestFAXentropyLlama:
""" """
Test case for Llama models using LoRA w multipack Test case for Llama models using LoRA w multipack

View File

@@ -7,7 +7,6 @@ import os
import unittest import unittest
from pathlib import Path from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
@@ -22,7 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("FIXME, mostly underused functionality")
class TestFusedLlama(unittest.TestCase): class TestFusedLlama(unittest.TestCase):
""" """
Test case for Llama models using Fused layers Test case for Llama models using Fused layers