Compare commits

...

43 Commits

Author SHA1 Message Date
Wing Lian
31799bdcc0 more parity across tests and docker images for packaging/setuptools 2025-03-21 08:56:01 -04:00
Wing Lian
25455ac25f make sure packaging version is consistent 2025-03-21 08:27:17 -04:00
Wing Lian
edea25bd58 comment out license for validation for now 2025-03-21 08:20:28 -04:00
Wing Lian
42e32223c9 try rolling back packaging and setuptools versions 2025-03-21 08:12:07 -04:00
Wing Lian
6e0fed0ce7 use license instead of license-file 2025-03-21 07:25:09 -04:00
Wing Lian
5ece44b4a8 try with reversion of packaging/setuptools/wheel install 2025-03-21 07:19:12 -04:00
Wing Lian
e7532c9b0c make sure ninja is installed 2025-03-21 06:57:06 -04:00
Wing Lian
2518a9b2a2 multiline fix 2025-03-20 20:51:16 -04:00
Wing Lian
faeae323cb install deepspeed by itself 2025-03-20 20:04:39 -04:00
Wing Lian
bb683644c3 deepspeed binary fixes hopefully 2025-03-20 19:52:07 -04:00
Wing Lian
7009a48398 bump deepspeed and set no binary 2025-03-20 14:01:01 -04:00
Wing Lian
ee529e2354 use nightly 2025-03-20 11:24:30 -04:00
Wing Lian
b2976e64ec add 12.8.1 cuda to the base matrix 2025-03-20 11:24:30 -04:00
Wing Lian
38df5a36ea bump HF versions except for trl (#2427) 2025-03-20 10:22:05 -04:00
Wing Lian
4d92a68a96 use default torch fused adamw optimizer as default as adamw_hf is deprecated (#2425)
* use default torch fused adamw optimizer as default as adamw_hf is deprecated

* make sure to have latest packaging installed

* bump packagingin requirements.txt too
2025-03-19 23:58:33 -04:00
SicariusSicariiStuff
85147ec430 Update README.md (#2360)
* Update README.md

wheel is needed

* feat: add ninja, setuptools, packing to installation steps

* fix: add missing instruction

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-03-17 08:39:17 -04:00
NanoCode012
51cd409488 Feat: minor docs improvements for RLHF and faq on embeddings (#2401) [skip ci]
* feat: add doc on shrink_embeddings and custom calling

* chore: rename inference doc

* fix: clarify same config is used for all cli

* chore: rearrange order inference qmd

* feat: add simpo to doc

* fix: update defaults

* feat: add rl configs to doc

* fix: ensure beta consistent with trl.beta

* fix: clarify about lora/fft

* chore: rename title

* chore: fix language

* feat: move config reference higher

* Update docs/getting-started.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/rlhf.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-17 08:39:04 -04:00
NanoCode012
7235123d44 chore(docs): add cookbook/blog link to docs (#2410) [skip ci] 2025-03-17 08:38:19 -04:00
Wing Lian
4f5eb42a73 remove reference to deprecated import (#2407) 2025-03-15 08:49:41 -04:00
Wing Lian
fbe54be6b8 only validate hf user token on rank 0 (#2408) 2025-03-13 23:29:06 -04:00
Wing Lian
04f6324833 build cloud images with torch 2.6.0 (#2413)
* build cloud images with torch 2.6.0

* nightlies too
2025-03-13 23:28:51 -04:00
Wing Lian
f0072f3b9d use max of 32 dataset processes if not explicit (#2403)
* use max of 32 dataset processes if not explicit

* change alternate min val for consistency
2025-03-11 12:02:58 -04:00
Wing Lian
59899b9817 pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
2025-03-11 12:02:43 -04:00
NanoCode012
4a736986fa fix(modal): add git pull when getting branch files (#2399) 2025-03-10 15:14:41 -04:00
Wing Lian
5d0f110a3b include iproute2 and nvtop in cloud image (#2393) 2025-03-10 15:13:38 -04:00
NanoCode012
83f8698b8a fix: create mount folder on modal if not exist (#2390) 2025-03-10 16:27:42 +07:00
xzuyn
60a11a6410 Use Latest Cut Cross Entropy (#2392)
* Update __init__.py

* Update README.md

* Update cutcrossentropy_install.py

* add test
2025-03-10 16:26:40 +07:00
NanoCode012
46a045e528 chore(doc): add faq when having no default chat_template (#2398)
* chore(doc): add faq when having no default chat_template

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/faq.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:50 +07:00
NanoCode012
3b477e08a0 feat(doc): add more info on RewardModel datasets (#2391)
* fix: reduce title size

* feat(doc): add rm dataset info

* Update docs/reward_modelling.qmd following suggestion

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:31 +07:00
NanoCode012
16dc6ee68d refactor: trl grpo configs to have descriptions (#2386)
* refactor: trl grpo configs to have descriptions

* chore: caps
2025-03-07 08:58:53 -05:00
Wing Lian
fa7c79b3b9 remove lion-pytorch as it's already handled upstream (#2389) 2025-03-07 08:58:15 -05:00
Wing Lian
ae66374156 Optimizer refactor and add Muon support (#2367)
* add muon optimizer

optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer

* fix import
2025-03-06 11:49:19 -05:00
Wing Lian
5e21b1a9da various fixes 20250305 (#2384)
* various validation fixes

* fix check for non-truthy value
2025-03-06 11:48:44 -05:00
mhenrichsen
575e5f28ec Update Tokenizer Overrides Handling in models.py (#1549)
* override special tokens mock code

* fix(doc): remove duplicate config

* feat: replace added_tokens in tokenizer and add test

* make sure to run tokenizer modification on rank 0 only

* use is local main process instead

* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-05 11:15:12 -05:00
xzuyn
0134093acc Add REX LR Scheduler (#2380)
* Update trainer_builder.py

* Update base.py

* Update __init__.py

* Update base.py

* Update base.py

* Update config.qmd

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* lint

* lint

* lint

* lint

* lint

* lint

* Update base.py

* Update base.py

* lint

* Update base.py

* Update base.py

* Move RexLR to `schedulers.py`

* Remove RexLR from `base.py`

* Fix tooltip formatting

* lint

* Create test_schedulers.py

* Use a default optimizer in test

* lint

* lint

* Add `warmup_steps` and `cosine_min_lr_ratio` to test

* lint
2025-03-05 10:26:11 -05:00
NanoCode012
d4de93a7bb feat(grpo): add reward_weights config and refactor (#2365) 2025-03-05 10:02:08 -05:00
NanoCode012
c8191394e9 fix(doc): add missing low_cpu_mem_usage config to docs (#2369) [skip ci] 2025-03-05 10:01:44 -05:00
NanoCode012
f18231c653 chore(doc): add clarification about mpi4py error on single gpu deepspeed (#2383) [skip ci]
* chore(doc): add clarification about mpi4py error on single gpu deepspeed

* fix: lint
2025-03-05 10:01:28 -05:00
NanoCode012
9ed4f6b3aa feat(doc): document drop_system_message and clarify limitation (#2381) [skip ci] 2025-03-05 10:01:16 -05:00
NanoCode012
05dddfc41d feat(doc): add docker images explanation (#2379) [skip ci]
* feat(doc): add docker images explanation

* chore: add link to dockerhub
2025-03-05 10:01:00 -05:00
NanoCode012
8e30917440 chore(docs): remove phorm (#2378) [skip ci] 2025-03-05 10:00:50 -05:00
NanoCode012
d883b11b6f fix(doc): add installation for cce to docs (#2375) [skip ci]
* fix(doc): add installation for cce to docs

* fix: format
2025-03-05 10:00:39 -05:00
Dan Saunders
f4910dd2ea train.py refactor (#2371)
* refactor train.py

* updates

* update

* combine like functions

* review comments
2025-03-05 08:58:33 -05:00
60 changed files with 1704 additions and 433 deletions

View File

@@ -40,6 +40,12 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -61,7 +67,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/Dockerfile-base
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -88,6 +88,11 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,6 +80,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -59,7 +59,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -147,7 +147,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |

View File

@@ -22,8 +22,8 @@ repos:
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v3.3.0
- repo: https://github.com/pylint-dev/pylint
rev: c8c96d20cde3552a79858c7456bb1483bf83d633
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -19,9 +19,6 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.
@@ -58,6 +55,7 @@ Features:
### Installation
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

View File

@@ -32,14 +32,16 @@ website:
contents:
- docs/getting-started.qmd
- docs/installation.qmd
- docs/cli.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Deployments"
contents:
- docs/docker.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/ray-integration.qmd
@@ -73,10 +75,6 @@ website:
- docs/debugging.qmd
- docs/nccl.qmd
- section: "Reference"
contents:
- docs/config.qmd
format:
html:
theme: darkly

View File

@@ -31,6 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -1,5 +1,5 @@
---
title: Config options
title: Config Reference
description: A complete list of all configuration options.
---
@@ -30,6 +30,8 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# (Internal use only)
# Used to identify which the model is based on
@@ -154,8 +156,6 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -163,6 +163,12 @@ datasets:
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
@@ -201,10 +207,46 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto'
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
rl_beta: # Optional[float]. The beta parameter for the RL training.
# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
# reward modelling: `True` or `False`
reward_model:
@@ -222,13 +264,13 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # repo path
push_dataset_to_hub: # Optional[str] repo_org/repo_name
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
@@ -445,7 +487,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -528,6 +570,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
@@ -548,6 +592,13 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:

View File

@@ -74,6 +74,10 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml

View File

@@ -129,6 +129,7 @@ You can mix and match within each approach or across approaches to train a model
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.

140
docs/docker.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -19,12 +19,24 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
**Q: The codes is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -50,3 +62,7 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
load_in_8bit: true
adapter: lora
datasets:
- path: teknium/GPT4-LLM-Cleaned
@@ -44,11 +46,15 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
::: {.callout-tip}
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
- To perform Full finetuning, remove these two lines.
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
:::
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies LoRA adapter layers
2. (If specified) applies QLoRA/LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
load_in_8bit: true
adapter: lora
# Training settings
@@ -104,8 +112,6 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```bash

View File

@@ -1,5 +1,5 @@
---
title: "Inference"
title: "Inference and Merging"
format:
html:
toc: true
@@ -9,10 +9,14 @@ execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
::: {.callout-tip}
Use the same config used for training on inference/merging.
:::
### Basic Inference {#sec-basic}
::: {.panel-tabset}

View File

@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
@@ -37,7 +38,7 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
@@ -65,6 +66,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}
@@ -105,7 +108,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:

View File

@@ -66,6 +66,10 @@ logic to be compatible with more of them.
</details>
::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The

View File

@@ -28,8 +28,23 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
::: {.callout-tip}
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
:::
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B
@@ -45,3 +60,5 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
@@ -297,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### IPO
As IPO is just DPO with a different loss function, all supported options for DPO works here.
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
```yaml
rl: ipo
@@ -343,8 +344,9 @@ ORPO supports the following types with the following dataset format:
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
remove_unused_columns: false
@@ -496,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
@@ -528,6 +534,7 @@ trl:
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
@@ -536,6 +543,21 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```
This method uses the same dataset format as [DPO](#dpo).
### Using local dataset files
```yaml

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"
[project]
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
# license = "Apache-2.0"
[project.scripts]
axolotl = "axolotl.cli.main:main"

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.2
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.3
packaging==23.2
peft==0.14.0
peft==0.15.0
transformers==4.49.0
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
deepspeed==0.16.4
trl==0.15.1
optimum==1.16.2
@@ -62,4 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3

View File

@@ -17,12 +17,12 @@ if v < V("2.4.0"):
cce_spec = importlib.util.find_spec("cut_cross_entropy")
UNINSTALL_PREFIX = ""
uninstall_prefix = ""
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
uninstall_prefix = "pip uninstall -y cut-cross-entropy && "
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
uninstall_prefix
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
)

View File

@@ -128,7 +128,7 @@ setup(
"flash-attn==2.7.4.post1",
],
"deepspeed": [
"deepspeed==0.16.1",
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
]
)
@@ -270,6 +270,7 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
@@ -288,6 +289,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -34,18 +35,20 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
del trainer
plugin_manager.post_train_unload(cfg)

View File

@@ -24,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
eval_dataset: Dataset | None = None
total_num_steps: int | None = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
num_proc = min(32, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -35,6 +35,7 @@ from transformers import (
EarlyStoppingCallback,
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
@@ -84,6 +85,7 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype
try:
@@ -91,13 +93,11 @@ try:
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
LOG = logging.getLogger(__name__)
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
"""Base class for trainer builder."""
_train_dataset = None
_eval_dataset = None
@@ -110,9 +110,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instad of trainer.push_to_hub.
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
@@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):
@@ -551,30 +551,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
@@ -658,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_arguments_kwargs.get("learning_rate"),
"weight_decay": training_arguments_kwargs.get("weight_decay"),
}
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
# Adam-specific kwargs
adam_kwargs = {}
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
else:
# Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
@@ -872,9 +918,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()

View File

@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
@@ -22,9 +23,11 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
@@ -115,6 +118,17 @@ class SchedulerMixin(Trainer):
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -154,47 +168,18 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
class OptimizerMixin(Trainer):
"""
Extend the base Trainer for axolotl helpers
Mixin class for shared handling of building custom optimizers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
@@ -281,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
@@ -314,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -366,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:

View File

@@ -9,6 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -31,30 +32,44 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
if not hasattr(cfg, "trl") or not cfg.trl:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
] = trl.vllm_gpu_memory_utilization
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
if trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.ref_model_mixup_alpha:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
if trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
return grpo_args_kwargs
@classmethod

View File

@@ -23,6 +23,8 @@ import importlib
import logging
from typing import OrderedDict
import torch
class BasePlugin:
"""
@@ -469,3 +471,14 @@ class PluginManager:
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -4,6 +4,22 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
```yaml

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
)

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class SpectrumArgs(BaseModel):
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -1,26 +1,29 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import inspect
import os
import signal
import sys
import weakref
from pathlib import Path
from typing import Tuple, Union
from typing import Any, Dict
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from datasets import Dataset
from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -32,17 +35,25 @@ try:
except ImportError:
BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__)
def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
def setup_model_and_tokenizer(
cfg: DictDefault,
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -55,11 +66,58 @@ def train(
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
# Get datasets
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load the model and peft_config
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
return model, tokenizer, peft_config, processor
def setup_reference_model(
cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> PreTrainedModel | None:
"""
Set up the reference model for RL training if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to use for the reference model.
Returns:
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
return model_ref
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
@@ -73,77 +131,22 @@ def train(
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
return cfg.resume_from_checkpoint
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
if cfg.fix_untrained_tokens:
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
# ray workers don't have access to this signal
if cfg.local_rank == 0 and not cfg.use_ray:
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -161,21 +164,22 @@ def train(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
):
"""
Execute the training process with appropriate backend configurations.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -187,15 +191,30 @@ def train(
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
# post training
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type:
@@ -203,16 +222,18 @@ def train(
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
# Handle ReLoRA early return case
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer
return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
# processes attempt to write the same file
if (
state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -244,7 +265,6 @@ def train(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -255,58 +275,241 @@ def train(
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""
Create a model card for the trained model if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object with model card creation capabilities.
"""
if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
try:
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
if cfg.datasets is not None and not rl:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
# Defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
return model, tokenizer
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
peft_config: PeftConfig | None,
):
"""
Save initial configurations before training.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to save.
model: The model to save configuration for.
peft_config: The PEFT configuration to save if applicable.
"""
# Create output_dir if it doesn't already exist
output_dir = Path(cfg.output_dir)
if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
# Pre-save adapter config so it's available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
peft_config.save_pretrained(cfg.output_dir)
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
def pretrain_hooks(_cfg, _trainer):
def setup_model_card(cfg: DictDefault):
"""
Run hooks right before kicking off the training
:param cfg:
:param trainer:
:return:
Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
def post_train_hooks(_cfg, _trainer):
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Run hooks right after training completes
:param cfg:
:param trainer:
:return:
Apply fixes for untrained tokens if configured.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
if "is_ds_zero3" in sig.parameters:
fix_kwargs["is_ds_zero3"] = is_ds_zero3
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
]:
"""
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
trainer setup.
Args:
cfg: The configuration dictionary with training parameters.
dataset_meta: Object with training, validation datasets and metadata.
Returns:
Tuple of:
- Trainer (Causal or RLHF)
- Model
- Tokenizer
- PEFT config
"""
# Load tokenizer, processor and model
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer)
# Get datasets from metadata
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Set up trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,
model_ref=model_ref,
peft_config=peft_config,
)
return (
trainer,
model,
tokenizer,
peft_config,
)
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
"""
Train a model on the given dataset.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Object with training, validation datasets and metadata
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
tokenizer,
peft_config,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
return model, tokenizer, trainer

View File

@@ -1,4 +1,5 @@
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging
@@ -64,6 +65,17 @@ class ChatTemplate(str, Enum):
metharme = "metharme" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -494,18 +506,8 @@ class HyperparametersConfig(BaseModel):
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF
Union[OptimizerNames, CustomSupportedOptimizers]
] = OptimizerNames.ADAMW_TORCH_FUSED
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -518,7 +520,7 @@ class HyperparametersConfig(BaseModel):
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"]]
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
@@ -727,7 +729,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
@@ -778,9 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -855,6 +857,7 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
@@ -1153,6 +1156,15 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:
@@ -1177,6 +1189,13 @@ class AxolotlInputConfig(
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@model_validator(mode="before")
@classmethod
def check_lr_groups(cls, data):
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):
@@ -1809,6 +1828,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data["torch_compile"] = False
return data
@model_validator(mode="before")
@classmethod
def check_beta_and_trl_beta_match(cls, data):
if data.get("beta") and data.get("trl", {}).get("beta"):
if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed")
return data
def handle_legacy_message_fields_logic(data: dict) -> dict:
"""

View File

@@ -1,7 +1,8 @@
"""
GRPO specific configuration args
"""
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: Optional[float] = None
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
@@ -20,16 +24,68 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)

View File

@@ -2,6 +2,7 @@
import functools
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -344,6 +345,7 @@ def load_tokenized_prepared_datasets(
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process():
return PartialState().is_main_process
return PartialState().is_local_main_process
def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
)
def hf_grad_checkpoint_unsloth_wrapper(
def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -24,7 +24,6 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import nn
from transformers import ( # noqa: F401
AddedToken,
@@ -57,8 +56,14 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.distributed import (
barrier,
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -165,7 +170,95 @@ def load_model_config(cfg):
return model_config
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
import json
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
barrier()
return tokenizer_dir
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -180,8 +273,18 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(
cfg.tokenizer_config,
tokenizer_path,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
@@ -389,8 +492,8 @@ class ModelLoader:
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.flash_attention:
self.patch_attention()
@@ -1256,7 +1359,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():
if (

View File

@@ -6,6 +6,80 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class RexLR(LRScheduler):
"""
Reflected Exponential (REX) learning rate scheduler.
- Original implementation: https://github.com/IvanVassi/REX_LR
- Original license: Apache 2.0
- Based on: https://arxiv.org/abs/2107.04197
Args:
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
max_lr (float): The maximum learning rate.
min_lr (float): The minimum learning rate.
total_steps (int): The total number of training steps.
num_warmup_steps (int): The number of warmup steps.
last_step (int): The index of last step.
"""
def __init__(
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
):
if min_lr > max_lr:
raise ValueError(
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
)
if num_warmup_steps > total_steps:
raise ValueError(
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
)
self.min_lr = min_lr
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)
@property
def last_step(self):
return self.last_epoch
@last_step.setter
def last_step(self, value):
self.last_epoch = value
def get_lr(self):
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
if 1 <= self.last_step <= self.num_warmup_steps:
return [
base_lr * self.last_step / self.num_warmup_steps
for base_lr in self.base_lrs
]
# Post-warmup phase: adjust step relative to the end of warmup.
step_after = self.last_step - self.num_warmup_steps
remaining_steps = self.total_steps - self.num_warmup_steps
# Avoid LR spiking
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
return [self.min_lr for _ in self.base_lrs]
mod_iter = step_after % remaining_steps
z = (remaining_steps - mod_iter) / remaining_steps
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
z / (0.1 + 0.9 * z)
)
return [base_lr * rex_factor for base_lr in self.base_lrs]
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion

View File

@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref
trainer_builder.peft_config = peft_config
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -14,7 +14,7 @@
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 6rem;
font-size: 5rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;

View File

@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,

View File

@@ -108,6 +108,12 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
)
@pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset():
# download the dataset
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
@pytest.fixture
def temp_dir():
# Create a temporary directory

View File

@@ -69,6 +69,51 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
[

View File

@@ -750,3 +750,66 @@ class TestMultiGPULlama:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -40,8 +40,8 @@ class TestReLoraLlama(unittest.TestCase):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 100,
"relora_warmup_steps": 20,
"relora_steps": 50,
"relora_warmup_steps": 10,
"relora_anneal_steps": 10,
"relora_prune_ratio": 0.9,
"relora_cpu_offload": True,
@@ -60,9 +60,9 @@ class TestReLoraLlama(unittest.TestCase):
"message_field_content": "value",
},
],
"warmup_steps": 20,
"warmup_steps": 10,
"num_epochs": 2,
"max_steps": 205, # at least 2x relora_steps
"max_steps": 105, # at least 2x relora_steps
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,

View File

@@ -66,6 +66,54 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{

View File

@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32

View File

@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
@with_temp_dir
@require_torch_2_5_1
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_5_1
def test_muon(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):

View File

@@ -0,0 +1,71 @@
"""
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_rex_scheduler(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -7,13 +7,13 @@ import tempfile
import unittest
from pathlib import Path
from conftest import snapshot_download_w_retry
from constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
@@ -69,7 +69,7 @@ class TestDatasetPreparation(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
snapshot_download_w_retry(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
@@ -81,7 +81,7 @@ class TestDatasetPreparation(unittest.TestCase):
# how to load it.
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"datasets": [
{
@@ -339,7 +339,7 @@ class TestDatasetPreparation(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
snapshot_download_w_retry(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
@@ -381,7 +381,7 @@ class TestDatasetPreparation(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
snapshot_download_w_retry(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,

View File

@@ -1,6 +1,7 @@
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers(unittest.TestCase):
class TestTokenizers:
"""
test class for the load_tokenizer fn
"""
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
}
)
tokenizer = load_tokenizer(cfg)
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
assert len(tokenizer) == 32001
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {
128041: "RANDOM_OVERRIDE_1",
128042: "RANDOM_OVERRIDE_2",
},
"output_dir": temp_dir,
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
128041
]
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
128042
]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}
)
with pytest.raises(
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
):
load_tokenizer(cfg)
if __name__ == "__main__":