Compare commits


40 Commits

Author SHA1 Message Date
Dan Saunders
cbcc795bb3 commenting out unused 2025-06-16 01:53:13 +00:00
Dan Saunders
e34b6f4dfe temp: trying another approach 2025-06-15 21:32:10 +00:00
Dan Saunders
f8f87321bd progress 2025-06-14 17:40:21 +00:00
Dan Saunders
7a88de4fa8 finish basic impl; change naming from SP -> CP to match torch 2025-06-13 09:51:06 -04:00
Dan Saunders
aced809989 progress (messy :O) 2025-06-12 18:54:41 +00:00
Dan Saunders
ae73123eae progress; move validation to pydantic model config 2025-06-07 06:58:59 +00:00
Dan Saunders
10d1e44943 SDPA context parallel 2025-06-06 00:34:12 +00:00
Wing Lian
7909bfb076 add manual seed for flaky test_geglu_backward test (#2763) [skip ci] 2025-06-05 09:23:17 -07:00
Wing Lian
cb03c765a1 add uv tooling for e2e gpu tests (#2750)
* add uv tooling for e2e gpu tests

* fixes from PR feedback

* simplify check

* fix env var

* make sure to use uv for other install

* use raw_dockerfile_image

* Fix import

* fix args to experimental dockerfile image call

* use updated modal versions
2025-06-05 07:25:06 -07:00
Timofey Klyubin
4440b4a1ce remove unused field for chat_template.default for DPO training (#2755) [skip ci]
* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

* lint and fix tests for new return value

* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

lint and fix tests for new return value

fix for updated expected fields for dpo

remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

fix test still expecting "messages" field

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-05 07:22:58 -07:00
NanoCode012
e8e45b3441 fix: remove hqq (#2759) [skip ci] 2025-06-05 07:22:23 -07:00
Wing Lian
c67910fa6f bump hf deps (#2735) [skip ci]
* bump hf deps

* upgrade liger-kernel too

* install cce from fork for transformers fix

* fix reference to vocab size in gemma3 patch

* use padding_idx instead of pad_token_id

* remove fixed gemma3 patch

* use updated cce fork

* fix local mllama cce patches w docstring

* add test for multipack with trainer setup and fix trainer for trainer refactor upstream

* bump modal version

* guard for iterable datasetS

* mllama model arch layout changed in latest transformers

* fix batch sampler with drop_last

* fix: address upstream vlm changes for lora

* fix: update references to old lora target path

* fix: remove mllama fa2 patch due to upstream fix

* fix: lora kernel patch path for multimodal models

* fix: removed mllama from quarto

* run test for came optim on 2.6.0+

* fix fsdp2 patch and remove deprecated patch

* make sure to set sequence_parallel_degree for grpo

* Add SP test for GRPO

* add sp to grpo config for trainer

* use reward_funcs as kwarg to grpo trainer

* fix the comprehension for reward funcs

* reward funcs already passed in as args

* init sp_group right before training

* fix check for adding models to SP context

* make sure to pass args to super

* upgrade deepspeed

* use updated trl and add reasoning flags for vllm

* patch the worker

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-06-05 07:20:33 -07:00
NanoCode012
787880215b fix(deepspeed): deepspeed config not being set for z3 (#2754)
* fix(deepspeed): deepspeed config not being set for z3

* fix: comments
2025-06-03 14:27:09 -07:00
NanoCode012
4b1a29c694 feat(modal): update docker tag to use torch2.6 from torch2.5 (#2749) [skip ci] 2025-06-03 14:26:07 -07:00
NanoCode012
d7fa60662e feat: add chat_template kwargs (#2694) [skip ci] 2025-06-03 14:25:26 -07:00
Dan Saunders
1d91d905c9 remove deprecated wandb env var (#2751)
* remove deprecated wandb env var

* remove os.environ wandb setting; unused loggers

* remove os.environ wandb setting; unused loggers
2025-06-03 14:04:15 -07:00
mhenrhcsen
2bf61d8e25 fix abbriviatation spelling error 2025-06-03 21:30:40 +02:00
mhenrhcsen
68788e419e feat: add Group Relative Policy Optimization (GPRO) to RLHF documentation 2025-06-03 21:30:40 +02:00
github-actions[bot]
94219f6ee8 chore: update pre-commit hooks (#2745)
* chore: update pre-commit hooks

* trigger linter when pre commit hooks are updated

* fix type checks from upgraded pre-commit

---------

Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-02 15:54:29 -07:00
Wing Lian
ecc719f5c7 add support for base image with uv (#2691) 2025-06-02 12:48:55 -07:00
NanoCode012
d5d0dc5938 fix: suppress non-axolotl logs unless it's warning or higher (#2724)
* fix: increase log level for root loggers and axolotl's

* fix: BasePlugin using wrong logger

* fix: update logger to take name from module

* feat: change logger class to AxolotlLogger to filter non-axolotl infos or below

* fix: change behavior to not disable existing loggers

* fix: update logging to respect correct env

* chore: fix comment

* fix: suppress accelerate log to LOG_LEVEL if not set

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-05-31 12:13:43 +07:00
NanoCode012
5e86c35322 fix(log): remove duplicate merge_lora param (#2742) [skip ci] 2025-05-31 12:13:31 +07:00
NanoCode012
6778856804 Fix: RL base feature parity (#2133)
* feat: add num_proc and load from cache for rl mapping

* fix: refactor sft and rl trainer to set same base args

* feat: add report_to to set run name

* fix: consolidate handling of fp16, bf16, tf32 kwarg

* chore: consolidate eval_strat, loraplus, lr sched, max_length

* fix: deprecate old types

* fix: adding missing Any

* fix: max_steps incorrectly set

* fix: remove unnecessary datacollator kwarg insert and pop

* fix: update default max_steps

* fix: add missing weight_decay handling

* fix: ignore max_length for grpo

* feat: update CI on trainer_builder

* fix: comments

* improve handling of warmup/logging steps

* use transformers default for logging steps, not None

* fix: remove redundant override

* fix: lint

* feat: allow custom optim for rl methods

* fix: duplicate optim setting

* fix(test): set sequence_parallel_degree default in base cfg

* feat: add handling for seed and SP/ring-attn config

* chore: add back return typing from rebase

* fix(test): use RLType directly to skip needing to validate

* feat: split training builder into sub modules

* fix: remove deprecated clause

* chore: add missing config to doc

* fix: update quarto autodoc

* fix: import path for trainer builder and submodules

* fix: remove redundant configs from rebase mistake

* chore: simplify dynamo check

* fix: optimizer_cls_and_kwargs to be passed into trainer_kwargs

* fix: add missing rex from rebase

* fix: move pop optimizer_cls_and_kwargs

* fix: pop optimizer cls in rl too

* fix: leftover bug from rebase

* fix: update handling of trainer_cls in RL

* fix: address pr feedback

* feat: call hook_pre_create_trainer for rl

* chore: lint

* fix: return notimplemented for ppo

* feat: moved torch compile to base and refactor collator setting

* chore: remove unused importlib.util import

* fix: optimizer cls not being popped

* feat: move epoch setting to base

* fix: catch unhandled custom optimizer

* fix: remove duplicate lora plus setting

* chore: refactor if condition

* chore: refactor set_base_training_args into smaller modules

* fix: address TrainerBuilderBase class variables to instance var

* fix: add handling for beta3 and episilon2

* fix: change to pass dict via arg instead of updating dict

* chore: simplify if condition

* fix: force access to lr & weight decay in case not provided to early error

* fix: remove log sweep

* chore: refactor if condition

* fix: address renamed cfg

* fix: improve handling of cosine hyp

* fix: remove unused params

* chore: refactor

* chore: clarify doc safetensors

* fix: update import path to be unified following comments

* fix: duplicate kwargs passed

* feat: return separate trainer_kwargs

* chore: refactor

* chore: refactor based on comments

* chore: refactor based on comments

* fix: move gpustats callback to base

* chore: create trainer_cls_args first based on comments

* fix: ipo label smoothing passed incorrectly

* feat: add optimizer parity for RL methods with test

* feat: add parity for optimizer in RM/PRM and add test

* fix: remove redundant function override for orpo/cpo batch metrics

* fix: improve handling of dpo_label_smoothing and merge issue

* fix: test fixture returning wrong field

* fix: address avoid direct modify fixture

* chore: minor refactor

* Revert "chore: refactor"

This reverts commit 99c8859eb0.

* feat: rename trainer_builder to builders

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-30 11:21:47 +07:00
Wing Lian
ec4ebfd997 Add a few items to faq (#2734)
* Add a few items to faq

* formatting

* chore: lint
2025-05-28 16:20:19 -04:00
Dan Saunders
bde8b5b6bd fix dist state init before deepspeed setup (#2737) 2025-05-28 14:59:57 -04:00
Dan Saunders
2962a398b7 Lora kernels fix (#2732)
* fix lora kernel patching and improve test

* simplification
2025-05-28 10:03:43 -04:00
salman
65c5481120 Rank 0-only logging (#2608)
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-28 14:57:30 +01:00
salman
5fca214108 QAT (#2590)
QAT and quantization w/torchao
2025-05-28 12:35:47 +01:00
NanoCode012
20fda75917 feat(doc): add google analytics to docs (#2708) 2025-05-28 15:51:21 +07:00
NanoCode012
6b6370f4e3 feat(doc): add info on how to use dapo / dr grpo and misc doc fixes (#2673) [skip ci]
* feat(doc): add info on how to use dapo / dr grpo

* chore: add missing config to docs

* fix: missing comment

* fix: add missing scheduler from schema

* chore: refactor lr scheduler docs

* fix: remove log_sweep
2025-05-28 15:51:04 +07:00
mashdragon
add2025253 Fix Mistral chat template (mistral_v7_tekken) (#2710) [skip ci]
Per 4b8dd8aae7 (d2h-482763)
2025-05-28 15:50:47 +07:00
artem
a703560a10 add two checks to handle legacy format interleaved multimodal ds (#2721) [skip ci]
* add two checks to handle legacy format interleaved ds

* fix: add warning about multiple image using legacy format

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-28 15:49:43 +07:00
NOHHYEOB, BAE
4a80d309e8 Add chat templates for command-a and aya-23-8B models (#2731) [skip ci]
* Add chat templates for command-a and aya model

* Fix: isolate for-loop update and remove unintended changes
2025-05-28 15:49:16 +07:00
NanoCode012
e33f225434 feat(doc): note lora kernel incompat with RLHF (#2706) [skip ci]
* feat(doc): note lora kernel incompat with RLHF

* fix: add validation following comments

* chore: fix typo following suggestion
2025-05-28 15:48:40 +07:00
NanoCode012
3e6948be97 Fix(doc): clarify data loading for local datasets and splitting samples (#2726) [skip ci]
* fix(doc): remove incorrect json dataset loading method

* fix(doc): clarify splitting only happens in completion mode

* fix: update local file loading on config doc

* fix: typo
2025-05-28 15:48:22 +07:00
github-actions[bot]
4a8af60d34 chore: update pre-commit hooks (#2729)
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-05-27 11:45:31 -04:00
Dan Saunders
a0941a9271 no need to generate diff file (#2728) 2025-05-27 11:44:06 -04:00
Dan Saunders
5eb01f3df1 Fix quarto (#2717)
* missing modules

* fix quarto complaints
2025-05-23 21:16:51 -04:00
xzuyn
d27c35ac44 Liger GraniteMoE (#2715) 2025-05-23 18:40:43 -04:00
Dan Saunders
a535b68043 update quarto for model loading refactor (#2716)
* update quarto for model loading refactor

* fix desc
2025-05-23 16:28:31 -04:00
222 changed files with 5297 additions and 3453 deletions

View File

@@ -17,7 +17,7 @@ jobs:
   build-base:
     if: github.repository_owner == 'axolotl-ai-cloud'
-    # this job needs to be run on self-hosted GPU runners...
-    runs-on: axolotl-gpu-runner
+    runs-on: ubuntu-latest-m
     strategy:
       fail-fast: false
       matrix:
@@ -28,42 +28,50 @@ jobs:
             python_version: "3.11"
             pytorch: 2.5.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
           - cuda: "124"
             cuda_version: 12.4.1
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.6.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
           - cuda: "126"
             cuda_version: 12.6.3
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.6.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
           - cuda: "126"
             cuda_version: 12.6.3
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.7.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
           - cuda: "128"
             cuda_version: 12.6.3
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.7.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
             python_version: "3.11"
             pytorch: nightly
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base-nightly"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: next
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          # # "next" is for release candidates of pytorch
+          # - cuda: "128"
+          #   cuda_version: 12.8.1
+          #   cudnn_version: ""
+          #   python_version: "3.11"
+          #   pytorch: next
+          #   torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          #   dockerfile: "Dockerfile-base-next"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -85,7 +93,59 @@ jobs:
         uses: docker/build-push-action@v4
         with:
           context: .
-          file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
+          file: ./docker/${{ matrix.dockerfile }}
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+          labels: ${{ steps.metadata.outputs.labels }}
+          build-args: |
+            CUDA_VERSION=${{ matrix.cuda_version }}
+            CUDNN_VERSION=${{ matrix.cudnn_version }}
+            CUDA=${{ matrix.cuda }}
+            PYTHON_VERSION=${{ matrix.python_version }}
+            PYTORCH_VERSION=${{ matrix.pytorch }}
+            TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
+  build-base-uv:
+    if: github.repository_owner == 'axolotl-ai-cloud'
+    runs-on: ubuntu-latest-m
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: "126"
+            cuda_version: 12.6.3
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.6.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker metadata
+        id: metadata
+        uses: docker/metadata-action@v5
+        with:
+          images: |
+            axolotlai/axolotl-base-uv
+      - name: Login to Docker Hub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Build
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          file: ./docker/${{ matrix.dockerfile }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
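The hunk above replaces an inline GitHub-expression ternary with a per-entry `dockerfile` matrix key. As an illustration only (not part of the workflow), a shell equivalent of what the old expression computed, with the filenames taken from the diff:

```shell
# Illustrative sketch: the old expression picked a base Dockerfile from the
# pytorch matrix value via a chained ternary; a case statement is the shell analog.
pytorch="nightly"
case "$pytorch" in
  nightly) file="./docker/Dockerfile-base-nightly" ;;
  next)    file="./docker/Dockerfile-base-next" ;;
  *)       file="./docker/Dockerfile-base" ;;
esac
echo "$file"
```

With the new layout the workflow simply interpolates `./docker/${{ matrix.dockerfile }}`, so adding a build variant means adding one matrix key instead of growing the expression.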

View File

@@ -9,6 +9,7 @@ on:
       - '.github/workflows/*.yml'
       - "*.[q]md"
       - "examples/**/*.y[a]?ml"
+      - ".pre-commit-config.yaml"
   workflow_dispatch:

 jobs:

View File

@@ -8,7 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
-      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
+      - 'src/axolotl/core/trainers/mixins/context_parallel.py'
       - 'src/axolotl/utils/distributed.py'
   workflow_dispatch:
   schedule:
@@ -59,7 +59,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -25,7 +25,6 @@ jobs:
           pre-commit autoupdate
           if [[ -n $(git status --porcelain) ]]; then
             echo "changes=true" >> $GITHUB_OUTPUT
-            git diff .pre-commit-config.yaml > pre-commit-update.diff
           fi
       - name: Create Pull Request
@@ -39,11 +38,3 @@ jobs:
           commit-message: "chore: update pre-commit hooks"
           body: |
             Automated PR to update pre-commit hooks to their latest versions.
-            <details>
-            <summary>Changes:</summary>
-            ```diff
-            ${{ steps.update.outputs.diff }}
-            ```
-            </details>
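For reference, the change-detection step that the removed `git diff` line belonged to keys off `git status --porcelain`. A minimal sketch run against a throwaway repository (the sample file content and paths here are made up for illustration):

```shell
# Sketch of the autoupdate workflow's change detection, in a disposable repo.
tmp=$(mktemp -d)
cd "$tmp"
git init -q .
git config user.email "ci@example.com"
git config user.name "ci"
# Pretend `pre-commit autoupdate` just modified the config file:
echo "rev: v1.0.1" > .pre-commit-config.yaml
GITHUB_OUTPUT="$tmp/out"
touch "$GITHUB_OUTPUT"
# Porcelain output is non-empty whenever the working tree is dirty.
if [[ -n $(git status --porcelain) ]]; then
  echo "changes=true" >> "$GITHUB_OUTPUT"
fi
cat "$GITHUB_OUTPUT"
```

After this change the workflow still sets the `changes` output; it just no longer captures the diff into a file for the PR body.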

View File

@@ -44,98 +44,6 @@ jobs:
         env:
           SKIP: no-commit-to-branch

-#  preload-cache:
-#    name: Preload HF cache
-#    runs-on: ubuntu-latest
-#    strategy:
-#      fail-fast: false
-#      matrix:
-#        python_version: ["3.11"]
-#        pytorch_version: ["2.6.0"]
-#    timeout-minutes: 20
-#
-#    env:
-#      AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
-#
-#    steps:
-#      - name: Check out repository code
-#        uses: actions/checkout@v4
-#
-#      - name: Restore HF cache
-#        id: hf-cache-restore
-#        uses: actions/cache/restore@v4
-#        with:
-#          path: |
-#            /home/runner/.cache/huggingface/hub/datasets--*
-#            /home/runner/.cache/huggingface/hub/models--*
-#          key: ${{ runner.os }}-hf-hub-cache-v2
-#
-#      - name: Restore Cache from S3
-#        id: hf-cache-restore-s3
-#        run: |
-#          mkdir -p /home/runner/.cache/huggingface/hub
-#          curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
-#
-#      - name: Setup Python
-#        uses: actions/setup-python@v5
-#        with:
-#          python-version: ${{ matrix.python_version }}
-#          cache: 'pip' # caching pip dependencies
-#
-#      - name: upgrade pip
-#        run: |
-#          pip3 install --upgrade pip
-#          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
-#
-#      - name: Install PyTorch
-#        run: |
-#          pip3 install torch==${{ matrix.pytorch_version }}
-#
-#      - name: Install dependencies
-#        run: |
-#          pip3 show torch
-#          pip3 install --no-build-isolation -U -e .
-#          python scripts/unsloth_install.py | sh
-#          python scripts/cutcrossentropy_install.py | sh
-#          pip3 install -r requirements-dev.txt -r requirements-tests.txt
-#
-#      - name: Make sure PyTorch version wasn't clobbered
-#        run: |
-#          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
-#
-#      - name: Ensure axolotl CLI was installed
-#        run: |
-#          axolotl --help
-#
-#      - name: Pre-Download dataset fixture
-#        run: |
-#          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
-#
-#      - name: Run tests
-#        run: |
-#          pytest -v tests/conftest.py
-#
-#      - name: Upload coverage to Codecov
-#        uses: codecov/codecov-action@v5
-#        with:
-#          token: ${{ secrets.CODECOV_TOKEN }}
-#          files: ./coverage.xml
-#          flags: unittests,pytorch-${{ matrix.pytorch_version }}
-#          fail_ci_if_error: false
-#
-#      - name: cleanup pip cache
-#        run: |
-#          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
-#
-#      - name: Save HF cache
-#        id: hf-cache
-#        uses: actions/cache/save@v4
-#        with:
-#          path: |
-#            /home/runner/.cache/huggingface/hub/datasets--*
-#            /home/runner/.cache/huggingface/hub/models--*
-#          key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
   pytest:
     name: PyTest
     runs-on: ubuntu-latest
@@ -151,15 +59,6 @@ jobs:
       - name: Check out repository code
        uses: actions/checkout@v4
-#      - name: Restore HF cache
-#        id: hf-cache-restore
-#        uses: actions/cache/restore@v4
-#        with:
-#          path: |
-#            /home/runner/.cache/huggingface/hub/datasets--*
-#            /home/runner/.cache/huggingface/hub/models--*
-#          key: ${{ runner.os }}-hf-hub-cache-v2
       - name: Restore Cache from S3
         id: hf-cache-restore-s3
         run: |
@@ -222,7 +121,6 @@ jobs:
   pytest-sdist:
     name: PyTest from Source Dist
     runs-on: ubuntu-latest
-#    needs: [preload-cache]
     strategy:
       fail-fast: false
       matrix:
@@ -234,15 +132,6 @@ jobs:
       - name: Check out repository code
        uses: actions/checkout@v4
-#      - name: Restore HF cache
-#        id: hf-cache-restore
-#        uses: actions/cache/restore@v4
-#        with:
-#          path: |
-#            /home/runner/.cache/huggingface/hub/datasets--*
-#            /home/runner/.cache/huggingface/hub/models--*
-#          key: ${{ runner.os }}-hf-hub-cache-v2
       - name: Restore Cache from S3
         id: hf-cache-restore-s3
         run: |
@@ -312,6 +201,13 @@ jobs:
             pytorch: 2.6.0
             num_gpus: 1
             axolotl_extras: vllm
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.6.0
+            num_gpus: 1
+            axolotl_extras:
+            dockerfile: "Dockerfile-uv.jinja"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -322,7 +218,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -333,6 +229,7 @@ jobs:
           echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
           echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
+          echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
           modal run cicd.e2e_tests
@@ -384,7 +281,7 @@ jobs:
      - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -395,6 +292,7 @@ jobs:
           echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
           echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
+          echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
           modal run cicd.e2e_tests
@@ -424,7 +322,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -19,15 +19,15 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 7.1.2
+    rev: 7.2.0
     hooks:
       - id: flake8
   - repo: https://github.com/pylint-dev/pylint
-    rev: v3.3.6
+    rev: v3.3.7
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.15.0
+    rev: v1.16.0
     hooks:
       - id: mypy
         additional_dependencies:
View File

@@ -242,16 +242,12 @@
 # early_stopping_patience: 3

 # # Specify a scheduler and kwargs to use with the optimizer
-# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
+# lr_scheduler: # 'one_cycle' | empty for cosine
 # lr_scheduler_kwargs:
 #   # For one_cycle optim
 #   lr_div_factor: # Learning rate div factor
-#   # For log_sweep optim
-#   log_sweep_min_lr:
-#   log_sweep_max_lr:

 # # Specify optimizer
 # # Valid values are driven by the Transformers OptimizerNames class, see:
 # # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134

View File

@@ -51,7 +51,7 @@ Features:
 - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
 - Python 3.11
-- PyTorch ≥2.4.1
+- PyTorch ≥2.5.1

 ### Installation

View File

@@ -17,7 +17,9 @@ quartodoc:
       - convert
       - prompt_tokenizers
       - logging_config
-      - core.trainer_builder
+      - core.builders.base
+      - core.builders.causal
+      - core.builders.rl
       - core.training_args
       - core.chat.messages
       - core.chat.format.chatml
@@ -43,6 +45,7 @@ quartodoc:
       - cli.vllm_serve
       - cli.cloud.base
       - cli.cloud.modal_
+      - cli.quantize
   - title: Trainers
     desc: Training implementations
     contents:
@@ -54,6 +57,15 @@ quartodoc:
       - core.trainers.grpo.trainer
       - core.trainers.grpo.sampler
       - core.trainers.utils
+  - title: Model Loading
+    desc: Functionality for loading and patching models, tokenizers, etc.
+    contents:
+      - loaders.model
+      - loaders.tokenizer
+      - loaders.processor
+      - loaders.adapter
+      - loaders.patch_manager
+      - loaders.constants
   - title: Mixins
     desc: Mixin classes for augmenting trainers
     contents:
@@ -63,7 +75,7 @@ quartodoc:
   - title: Context Managers
     desc: Context managers for altering trainer behaviors
     contents:
-      - utils.ctx_managers.sequence_parallel
+      - utils.ctx_managers.context_parallel
   - title: Prompt Strategies
     desc: Prompt formatting strategies
     contents:
@@ -117,17 +129,16 @@ quartodoc:
       - monkeypatch.trainer_fsdp_optim
       - monkeypatch.transformers_fa_utils
       - monkeypatch.unsloth_
-      - monkeypatch.attention.mllama
       - monkeypatch.data.batch_dataset_fetcher
       - monkeypatch.mixtral
+      - monkeypatch.gradient_checkpointing.offload_cpu
+      - monkeypatch.gradient_checkpointing.offload_disk
   - title: Utils
     desc: Utility functions
     contents:
-      - utils.models
       - utils.tokenization
       - utils.chat_templates
       - utils.lora
-      - utils.lora_embeddings
       - utils.model_shard_quant
       - utils.bench
       - utils.freeze
@@ -138,8 +149,7 @@ quartodoc:
       - utils.optimizers.adopt
       - utils.data.pretraining
       - utils.data.sft
-      - utils.gradient_checkpointing.offload_cpu
-      - utils.gradient_checkpointing.offload_disk
+      - utils.quantization
   - title: Schemas
     desc: Pydantic data models for Axolotl config
     contents:
@@ -189,12 +199,14 @@ quartodoc:
       - utils.callbacks.lisa
       - utils.callbacks.mlflow_
       - utils.callbacks.comet_
+      - utils.callbacks.qat

 website:
   title: "Axolotl"
   description: "We make fine-tuning accessible, scalable, and fun"
   favicon: favicon.jpg
+  google-analytics: "G-9KYCVJBNMQ"
   navbar:
     logo: image/axolotl_logo_digital_white.svg
     title: false
@@ -247,6 +259,8 @@ website:
         - docs/lr_groups.qmd
         - docs/lora_optims.qmd
         - docs/dataset_loading.qmd
+        - docs/qat.qmd
+        - docs/quantize.qmd
       - section: "Core Concepts"
         contents:
@@ -260,7 +274,7 @@ website:
         - docs/unsloth.qmd
         - docs/torchao.qmd
         - docs/custom_integrations.qmd
-        - docs/sequence_parallelism.qmd
+        - docs/context_parallelism.qmd
       - section: "Troubleshooting"
         contents:

cicd/Dockerfile-uv.jinja (new file, 52 additions)

View File

@@ -0,0 +1,52 @@
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store


@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
  df_args = {
      "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
      "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
-     "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
+     "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
-     "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
+     "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
-     "CUDA": os.environ.get("CUDA", "121"),
+     "CUDA": os.environ.get("CUDA", "124"),
      "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
      "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
      "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
@@ -55,7 +55,7 @@ VOLUME_CONFIG = {
  }
  N_GPUS = int(os.environ.get("N_GPUS", 2))
- GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
+ GPU_CONFIG = f"H100:{N_GPUS}"
  def run_cmd(cmd: str, run_folder: str):

@@ -8,8 +8,9 @@ import tempfile
  import jinja2
  import modal
+ import modal.experimental
  from jinja2 import select_autoescape
- from modal import App, Image
+ from modal import App
  cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -17,14 +18,15 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
  template_env = jinja2.Environment(
      loader=template_loader, autoescape=select_autoescape()
  )
- df_template = template_env.get_template("Dockerfile.jinja")
+ dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
+ df_template = template_env.get_template(dockerfile)
  df_args = {
      "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
      "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
-     "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
+     "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
-     "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
+     "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
-     "CUDA": os.environ.get("CUDA", "121"),
+     "CUDA": os.environ.get("CUDA", "124"),
      "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
      "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
      "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
@@ -38,11 +40,11 @@ temp_dir = tempfile.mkdtemp()
  with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
      f.write(dockerfile_contents)
- cicd_image = Image.from_dockerfile(
+ cicd_image = modal.experimental.raw_dockerfile_image(
      pathlib.Path(temp_dir) / "Dockerfile",
-     context_mount=None,
+     # context_mount=None,
      force_build=True,
-     gpu="A10G",
+     # gpu="A10G",
  ).env(df_args)
  app = App("Axolotl CI/CD", secrets=[])
@@ -55,7 +57,7 @@ VOLUME_CONFIG = {
  }
  N_GPUS = int(os.environ.get("N_GPUS", 1))
- GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
+ GPU_CONFIG = f"L40S:{N_GPUS}"
  def run_cmd(cmd: str, run_folder: str):

docker/Dockerfile-uv-base Normal file

@@ -0,0 +1,36 @@
ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
ENV UV_TORCH_BACKEND="cu${CUDA}"
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& git lfs install --skip-repo \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:${PATH}"
RUN uv python install ${PYTHON_VERSION}
WORKDIR /workspace
RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic


@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
  This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
+ ### quantize
+ Quantizes a model using the quantization configuration specified in your YAML file.
+ ```bash
+ axolotl quantize config.yml
+ ```
+ See [Quantization](./quantize.qmd) for more details.
  ## Legacy CLI Usage


@@ -65,6 +65,20 @@ bnb_config_kwargs:
  bnb_4bit_quant_type: nf4
  bnb_4bit_use_double_quant: true
+ # quantization aware training
+ qat:
+   activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
+   weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
+   group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
+   fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
+ # post-training quantization
+ quantization:
+   weight_dtype: # Optional[str] = "int8". Quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
+   activation_dtype: # Optional[str] = "int8". Quantization layout to use for activation quantization. Valid options are "int4" and "int8"
+   group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
+   quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
  # Whether you are training a 4-bit GPTQ quantized model
  gptq: true
@@ -98,8 +112,10 @@ plugins:
  # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  # A list of one or more datasets to finetune the model with
+ # See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
+ # See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
  datasets:
-   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
+   # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
    - path: vicgalle/alpaca-gpt4
      # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
      type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -221,7 +237,7 @@ datasets:
  # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
  shuffle_merged_datasets: true
- Deduplicates datasets and test_datasets with identical entries.
+ # Deduplicates datasets and test_datasets with identical entries.
  dataset_exact_deduplication: true
  # A list of one or more datasets to eval the model with.
@@ -270,10 +286,25 @@ trl:
  num_generations: # Optional[int]. Number of generations to sample.
  log_completions: # Optional[bool]. Whether to log completions.
+ num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
  sync_ref_model: # Optional[bool]. Whether to sync the reference model.
  ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
  ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
+ scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
+ temperature: # Optional[float]. Sampling temperature for the GRPO policy.
+ top_p: # Optional[float]. Top-p sampling probability for the generation policy.
+ top_k: # Optional[int]. Top-k sampling for the generation policy.
+ min_p: # Optional[float]. Minimum probability for the generation policy.
+ repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
+ num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
+ epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
+ epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
+ use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
+ loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
+ mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
  # reward modelling: `True` or `False`
@@ -483,6 +514,7 @@ output_dir: ./completed-model
  # setting to `auto` will enable torch compile when torch>=2.5.1
  torch_compile: # Optional[Union[Literal["auto"], bool]]
  torch_compile_backend: # Optional[str]
+ torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
  # Training hyperparameters
@@ -529,7 +561,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
  loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
  loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
- # Save model as safetensors (require safetensors package)
+ # Save model as safetensors (requires safetensors package). Default True
  save_safetensors:
  # Whether to mask out or include the human's prompt from the training labels
@@ -551,7 +583,24 @@ gradient_checkpointing: false
  early_stopping_patience: 3
  # Specify a scheduler and kwargs to use with the optimizer
- lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
+ # Valid values are driven by the Transformers SchedulerType class, see:
+ # https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
+ # Valid values include:
+ # - 'linear'
+ # - 'cosine' (default)
+ # - 'cosine_with_restarts'
+ # - 'polynomial'
+ # - 'constant'
+ # - 'constant_with_warmup'
+ # - 'inverse_sqrt'
+ # - 'reduce_lr_on_plateau'
+ # - 'cosine_with_min_lr'
+ # - 'warmup_stable_decay'
+ # Additional schedulers include:
+ # - 'one_cycle'
+ # - 'rex'
+ lr_scheduler:
  lr_scheduler_kwargs:
  cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
  cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -569,7 +618,7 @@ lr_div_factor: # Learning rate div factor
  #
  # Valid values for 'optimizer' include:
  # - adamw_torch
- # - adamw_torch_fused
+ # - adamw_torch_fused (default)
  # - adamw_torch_xla
  # - adamw_torch_npu_fused
  # - adamw_apex_fused
@@ -715,13 +764,13 @@ ddp_timeout:
  ddp_bucket_cap_mb:
  ddp_broadcast_buffers:
- # Sequence parallelism
+ # Context parallelism
  # Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
  # Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
  # E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
  # subsequences, or set to 4 to split into four equal-sized subsequences.
- # See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
+ # See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
- sequence_parallel_degree:
+ context_parallel_degree:
  # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
  # Must evenly divide the number of KV heads in your model.
  heads_k_stride: 1
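As an illustration of the renamed options above, a minimal sketch with hypothetical values (2 GPUs splitting each sequence in half; see the comments above for constraints):

```yaml
# illustrative values, not defaults
context_parallel_degree: 2  # split each sequence across 2 of the available GPUs
heads_k_stride: 1           # must evenly divide the model's number of KV heads
```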


@@ -36,10 +36,6 @@ It is typically recommended to save your dataset as `.jsonl` due to its flexibil
  Axolotl supports loading from a Hugging Face hub repo or from local files.
- ::: {.callout-important}
- For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
- :::
  ### Pre-training from Hugging Face hub datasets
  As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
@@ -77,18 +73,21 @@ datasets:
      type: completion
  ```
- From local files (either example works):
+ From local files:
  ```yaml
  datasets:
    - path: A.jsonl
      type: completion
-   - path: json
-     data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
+   - path: B.jsonl
      type: completion
  ```
+ ::: {.callout-important}
+ For `completion` only, Axolotl will split texts that exceed the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
+ :::
  ### Pre-training dataset configuration tips
  #### Setting max_steps


@@ -54,7 +54,7 @@ datasets:
  #### Files
- Usually, to load a JSON file, you would do something like this:
+ To load a JSON file, you would do something like this:
  ```python
  from datasets import load_dataset
@@ -66,20 +66,12 @@ Which translates to the following config:
  ```yaml
  datasets:
-   - path: json
-     data_files: /path/to/your/file.jsonl
- ```
- However, to make things easier, we have added a few shortcuts for loading local dataset files.
- You can just point the `path` to the file or directory along with the `ds_type` to load the dataset. The below example shows for a JSON file:
- ```yaml
- datasets:
-   - path: /path/to/your/file.jsonl
+   - path: data.json
      ds_type: json
  ```
+ In the example above, we simply point `path` to the file or directory and set `ds_type` to load the dataset.
  This works for CSV, JSON, Parquet, and Arrow files.
  ::: {.callout-tip}


@@ -36,7 +36,6 @@ Tags examples:
  - `main-base-py3.11-cu126-2.7.0`
  - `main-base-py3.11-cu124-2.6.0`
  - `main-base-py3.11-cu124-2.5.1`
- - `main-base-py3.11-cu124-2.4.1`
  ## Main
@@ -77,12 +76,10 @@ Tags examples:
  - `main-py3.11-cu126-2.7.0`
  - `main-py3.11-cu124-2.6.0`
  - `main-py3.11-cu124-2.5.1`
- - `main-py3.11-cu124-2.4.1`
  - `main-latest`
  - `main-20250303-py3.11-cu124-2.6.0`
  - `main-20250303-py3.11-cu124-2.5.1`
- - `main-20250303-py3.11-cu124-2.4.1`
- - `0.7.1`
+ - `0.9.2`
  ## Cloud


@@ -110,3 +110,17 @@ description: Frequently asked questions
  > A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
  > Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
+ **Q: `Data processing error: CAS service error`**
+ > A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
+ **Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.`**
+ > A: Depending on the version of torch, you may need to include this in your YAML:
+ > ```yaml
+ > flex_attn_compile_kwargs:
+ >   dynamic: false
+ >   mode: max-autotune-no-cudagraphs
+ > ```


@@ -180,7 +180,7 @@ Now that you have the basics, you might want to:
  Check our other guides for details on these topics:
  - [Configuration Guide](config.qmd) - Full configuration options
- - [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
+ - [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
  - [Dataset Formats](dataset-formats) - Working with different data formats
  - [Multi-GPU Training](multi-gpu.qmd)
  - [Multi-Node Training](multi-node.qmd)


@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
  - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
  - Python ≥3.10
- - PyTorch ≥2.4.1
+ - PyTorch ≥2.5.1
  ## Installation Methods {#sec-installation-methods}
@@ -41,6 +41,40 @@ installed) in order not to clobber it, and so that we set the correct version of
  dependencies that are specific to the PyTorch version or other installed
  co-dependencies.
+ ### uv Installation {#sec-uv}
+ uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
+ Install uv if not already installed:
+ ```{.bash}
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ source $HOME/.local/bin/env
+ ```
+ Choose your CUDA version to use with PyTorch, e.g. `cu124`, `cu126`, `cu128`,
+ then create and activate the venv:
+ ```{.bash}
+ export UV_TORCH_BACKEND=cu126
+ uv venv --no-project --relocatable
+ source .venv/bin/activate
+ ```
+ Install PyTorch (2.6.0 recommended):
+ ```{.bash}
+ uv pip install packaging setuptools wheel
+ uv pip install torch==2.6.0
+ uv pip install awscli pydantic
+ ```
+ Install axolotl from PyPI:
+ ```{.bash}
+ uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
+ # optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
+ uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
+ ```
  ### Edge/Development Build {#sec-edge-build}
  For the latest features between releases:


@@ -84,6 +84,10 @@ lora_qkv_kernel: true
  lora_o_kernel: true
  ```
+ ::: {.callout-note}
+ Currently, LoRA kernels are not supported for RLHF training, only SFT.
+ :::
  ## Requirements
  - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)


@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:
  - DeepSpeed (recommended)
  - FSDP (Fully Sharded Data Parallel)
- - Sequence parallelism
+ - Context parallelism
  - FSDP + QLoRA
  ## DeepSpeed {#sec-deepspeed}
@@ -80,14 +80,14 @@ fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  ```
- ## Sequence parallelism {#sec-sequence-parallelism}
+ ## Context parallelism {#sec-sequence-parallelism}
- We support sequence parallelism (SP) via the
+ We support context parallelism (CP) via the
  [ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
  allows one to split up sequences across GPUs, which is useful in the event that a
  single sequence causes OOM errors during model training.
- See our [dedicated guide](sequence_parallelism.qmd) for more information.
+ See our [dedicated guide](context_parallelism.qmd) for more information.
  ### FSDP + QLoRA {#sec-fsdp-qlora}


@@ -43,7 +43,7 @@ datasets:
  # leave the vision model and vision tower frozen
  # load_in_8bit: true
  adapter: lora
- lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+ lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
  # (optional) if you want to resize images to a set size
  image_size: 512

docs/qat.qmd Normal file

@@ -0,0 +1,32 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of quantized models by applying "fake" quantization to the model's weights (and, optionally, activations) during training. This fake quantization lets the model adjust to the noise introduced by quantization, so that when the model is eventually quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide support for QAT and post-training quantization (PTQ) in axolotl.
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model) and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat) for more details.
## Configuring QAT in Axolotl
To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
Once you have finished training, quantize your model using the same quantization configuration you used during training. You can use the [`quantize` command](./quantize.qmd) to do this.
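For example, a minimal QAT sketch with illustrative (non-default) values for the keys documented above:

```yaml
# train-time QAT settings (illustrative values)
qat:
  weight_dtype: int4              # fake-quantize weights to int4
  activation_dtype: int8          # fake-quantize activations to int8
  group_size: 32                  # per-group quantization group size
  fake_quant_after_n_steps: 1000  # let training stabilize before enabling fake quantization
```

Running `axolotl quantize` afterwards with this same config applies the matching real quantization.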

docs/quantize.qmd Normal file

@@ -0,0 +1,53 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).
::: {.callout-note}
We do not currently support quantization formats such as GGUF, GPTQ, or EXL2.
:::
## Configuring Quantization in Axolotl
Quantization is configured using the `quantization` key in your configuration file.
```yaml
base_model: # The path to the model to quantize.
quantization:
  weight_dtype: # Optional[str] = "int8". Quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
  activation_dtype: # Optional[str] = "int8". Quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
  quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
output_dir: # The path to the output directory.
```
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd). To do this, use the existing QAT configuration file which
you used to train the model:
```yaml
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
```bash
axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.


@@ -16,7 +16,8 @@ feedback. Various methods include, but not limited to:
  - [Identity Preference Optimization (IPO)](#ipo)
  - [Kahneman-Tversky Optimization (KTO)](#kto)
  - [Odds Ratio Preference Optimization (ORPO)](#orpo)
+ - [Group Relative Policy Optimization (GRPO)](#grpo)
- - Proximal Policy Optimization (PPO) (not yet supported in axolotl)
+ - Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
  ## RLHF using Axolotl
@@ -582,7 +583,20 @@ datasets:
  To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
- To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
+ To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
+ #### GRPO with DAPO/Dr. GRPO loss
+ The DAPO paper, and subsequently the Dr. GRPO paper, proposed alternative loss formulations for GRPO to remedy the penalty on longer responses.
+ ```yaml
+ trl:
+   loss_type: dr_grpo
+   # Normalizes loss based on max completion length (default: 256)
+   max_completion_length:
+ ```
+ For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### SimPO ### SimPO

View File

@@ -1,16 +1,16 @@
 ---
-title: Sequence Parallelism
+title: Context Parallelism
 description: Train with long sequences split across multiple GPUs.
 ---

-Sequence parallelism is a technique that splits sequences across multiple GPUs,
+Context parallelism is a technique that splits sequences across multiple GPUs,
 allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
 GPU processes a different portion of the sequence, and the results are aggregated
 through a ring communication pattern.

-## When to Use Sequence Parallelism
+## When to Use Context Parallelism

-Use sequence parallelism when:
+Use context parallelism when:

 - You need to train with sequence lengths that don't fit into a single GPU's memory
 - You have multiple GPUs available

@@ -18,11 +18,11 @@ Use sequence parallelism when:
 ## Configuration

-To enable sequence parallelism, add the following to your configuration file:
+To enable context parallelism, add the following to your configuration file:

 ```yaml
 # Set to a divisor (> 1) of the number of GPUs available
-sequence_parallel_degree: 4  # Split sequences across 4 GPUs
+context_parallel_degree: 4  # Split sequences across 4 GPUs
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,23 +30,23 @@ heads_k_stride: 1
 ring_attn_func:
 ```

-The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
+The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:

 - With 8 GPUs, valid values would be 2, 4, or 8
 - With 4 GPUs, valid values would be 2 or 4

 ## Implementation Details

-When sequence parallelism is enabled:
+When context parallelism is enabled:

-1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
+1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
 2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
 3. Position IDs are adjusted to maintain proper relative positions
 4. The trainer uses special ring communication patterns for attention operations

 ## Requirements

-To use sequence parallelism, you need:
+To use context parallelism, you need:

 - Multiple GPUs (at least 2)
 - The `ring-flash-attn` package. Install with:

@@ -66,7 +66,7 @@ sequence_len: 8192
 ...
-sequence_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
+context_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -79,22 +79,22 @@ ring_attn_func:

 This will train the Llama 3 8B model with 8K context length, with each sequence split
 into 2 subsequences of length 4096 across 2 GPUs.

-## Sample Packing with Sequence Parallelism
+## Sample Packing with Context Parallelism

-Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
+Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

 1. Samples are first packed together
-2. The packed sequences are then divided across GPUs in the sequence parallel group
+2. The packed sequences are then divided across GPUs in the context parallel group
 3. Position IDs are automatically adjusted to maintain proper relative positions

 ## Effect on Batch Size

-When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
+When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:

-- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
+- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
 - The number of batches processed per step decreases

 For example:

-- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
+- With 8 GPUs and no context parallelism: 8 different batches processed per step
-- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
+- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
 - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
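The batch-size arithmetic described in the doc above can be sketched as a quick sanity check (an illustrative helper, not part of Axolotl's API):

```python
def effective_global_batch_size(
    num_gpus, micro_batch_size, context_parallel_degree=1,
    gradient_accumulation_steps=1,
):
    # GPUs in the same context-parallel group consume the *same* batch
    # (each holds a different slice of every sequence), so only
    # num_gpus // context_parallel_degree distinct batches run per step.
    data_parallel_size = num_gpus // context_parallel_degree
    return data_parallel_size * micro_batch_size * gradient_accumulation_steps

print(effective_global_batch_size(8, 2))                             # 16
print(effective_global_batch_size(8, 2, context_parallel_degree=4))  # 4
```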

View File

@@ -28,7 +28,7 @@ pad_to_sequence_len: true
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 wandb_project:
 wandb_entity:

View File

@@ -30,7 +30,7 @@ pad_to_sequence_len: false
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 wandb_project:
 wandb_entity:

View File

@@ -29,7 +29,7 @@ pad_to_sequence_len: false
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 wandb_project:
 wandb_entity:

View File

@@ -0,0 +1,79 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
output_dir: ./outputs/qat_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
qat:
  activation_dtype: int8
  weight_dtype: int4
  group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|end_of_text|>

View File

@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
 datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.1
 output_dir: ./outputs/lora-out

@@ -38,6 +38,7 @@ wandb_log_model:
 gradient_accumulation_steps: 2
 micro_batch_size: 2
 num_epochs: 1
 optimizer: adamw_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 wandb_project:
 wandb_entity:

View File

@@ -27,7 +27,7 @@ pad_to_sequence_len: false
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 wandb_project:
 wandb_entity:

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
 wandb_project:
 wandb_entity:

View File

@@ -0,0 +1,78 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
output_dir: ./outputs/qat_out/
sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
qat:
  activation_dtype: int8
  weight_dtype: int4
  group_size: 256
  fake_quant_after_n_steps: 1000
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true
special_tokens:

View File

@@ -6,21 +6,20 @@ triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
-liger-kernel==0.5.9
+liger-kernel==0.5.10
 # END section

 packaging==23.2
-huggingface_hub==0.31.0
+huggingface_hub==0.32.2
 peft==0.15.2
-transformers==4.51.3
+transformers==4.52.3
 tokenizers>=0.21.1
-accelerate==1.6.0
+accelerate==1.7.0
-datasets==3.5.1
+datasets==3.6.0
-deepspeed>=0.15.4
+deepspeed>=0.17.0
-trl==0.17.0
+trl==0.18.1
-hf_xet==1.1.0
+hf_xet==1.1.2
-hqq==0.2.5
 optimum==1.16.2
 hf_transfer
@@ -63,7 +62,7 @@ langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
-torchao==0.9.0
+torchao==0.10.0
 schedulefree==1.4.1
 axolotl-contribs-lgpl==0.0.6

View File

@@ -9,6 +9,8 @@ except ImportError as exc:
     raise ImportError("Install torch via `pip install torch`") from exc

 from packaging.version import Version as V

+USE_UV = "--uv" in sys.argv[1:]
+
 v = V(torch.__version__)

 # no cut-cross-entropy support for torch < 2.4.0
@@ -23,7 +25,9 @@ if cce_spec:
     if not importlib.util.find_spec("cut_cross_entropy.transformers"):
         UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
+    UV_PREFIX = "uv " if USE_UV else ""
+
     print(
         UNINSTALL_PREFIX
-        + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
+        + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
     )

View File

@@ -11,7 +11,7 @@
 =@#  @#  #@=  #@ =#@@@@#=  +#@@= +#@@@@#=  .##@@+  @@
 @@@@  @@@@@@@@@@@@@@@@
-Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
+Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
 ```
 cd /workspace

View File

@@ -1,11 +1,15 @@
 # noqa
 # pylint: skip-file
+import sys
+
 try:
     import torch
 except ImportError:
     raise ImportError("Install torch via `pip install torch`")

 from packaging.version import Version as V

+use_uv = "--uv" in sys.argv[1:]
+
 v = V(torch.__version__)
 cuda = str(torch.version.cuda)
 try:
@@ -31,6 +35,7 @@ elif v < V("2.6.0"):
 else:
     raise RuntimeError(f"Torch = {v} too new!")
 x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
+uv_prefix = "uv " if use_uv else ""
 print(
-    f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
+    f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
 )

View File

@@ -118,7 +118,7 @@ extras_require = {
         "yunchang==0.6.0",
     ],
     "deepspeed": [
-        "deepspeed==0.15.4",
+        "deepspeed==0.17.0",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [

View File

@@ -28,7 +28,6 @@ class TrainerCliArgs:
     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
     debug_num_examples: int = field(default=0)
-    merge_lora: bool = field(default=False)
     prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)
     main_process_port: Optional[int] = field(default=None)

@@ -89,6 +88,26 @@ class VllmServeCliArgs:
         },
     )

+    enable_reasoning: Optional[bool] = field(
+        default=None,
+    )
+    reasoning_parser: Optional[str] = field(
+        default=None,
+    )
+
+
+@dataclass
+class QuantizeCliArgs:
+    """Dataclass with CLI arguments for `axolotl quantize` command."""
+
+    base_model: Optional[str] = field(default=None)
+    weight_dtype: Optional[str] = field(default=None)
+    activation_dtype: Optional[str] = field(default=None)
+    quantize_embedding: Optional[bool] = field(default=None)
+    group_size: Optional[int] = field(default=None)
+    output_dir: Optional[str] = field(default=None)
+

 @dataclass
 class EvaluateCliArgs:

View File

@@ -1,6 +1,5 @@
 """Various checks for Axolotl CLI."""

-import logging
 import os
 from pathlib import Path

@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError

-LOG = logging.getLogger(__name__)
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)


 def check_accelerate_default_config() -> None:

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
         return res

     def get_image(self):
-        docker_tag = "main-py3.11-cu124-2.5.1"
+        docker_tag = "main-py3.11-cu124-2.6.0"
         if self.config.docker_tag:
             docker_tag = self.config.docker_tag
         docker_image = f"axolotlai/axolotl:{docker_tag}"

View File

@@ -1,7 +1,6 @@
 """Configuration loading and processing."""

 import json
-import logging
 import os
 import tempfile
 from pathlib import Path

@@ -22,11 +21,12 @@ from axolotl.utils.config import (
     validate_config,
 )
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__, use_environ=True)


 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:

@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
     )

     if len(yaml_files) == 1:
-        print(f"Using default YAML file '{yaml_files[0]}'")
+        LOG.info(f"Using default YAML file '{yaml_files[0]}'")
         return str(yaml_files[0])

-    print("Choose a YAML file:")
+    LOG.info("Choose a YAML file:")
     for idx, file in enumerate(yaml_files):
-        print(f"{idx + 1}. {file}")
+        LOG.info(f"{idx + 1}. {file}")

     chosen_file = None
     while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
             if 1 <= choice <= len(yaml_files):
                 chosen_file = str(yaml_files[choice - 1])
             else:
-                print("Invalid choice. Please choose a number from the list.")
+                LOG.info("Invalid choice. Please choose a number from the list.")
         except ValueError:
-            print("Invalid input. Please enter a number.")
+            LOG.info("Invalid input. Please enter a number.")

     return chosen_file

View File

@@ -1,6 +1,5 @@
 """CLI to run evaluation on a model."""

-import logging
 import os
 from pathlib import Path
 from typing import Union

@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.evaluate import evaluate
 from axolotl.utils import patch_optimized_env
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

View File

@@ -1,7 +1,6 @@
 """CLI to run inference on a trained model."""

 import importlib
-import logging
 import sys
 from pathlib import Path
 from threading import Thread

@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
     get_chat_template_from_config,
 )
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 def get_multi_line_input() -> str:
def get_multi_line_input() -> str: def get_multi_line_input() -> str:

View File

@@ -2,7 +2,6 @@
 # pylint: disable=redefined-outer-name

-import logging
 import os
 import subprocess  # nosec B404
 import tempfile

@@ -17,6 +16,7 @@ import axolotl
 from axolotl.cli.args import (
     EvaluateCliArgs,
     PreprocessCliArgs,
+    QuantizeCliArgs,
     TrainerCliArgs,
     VllmServeCliArgs,
 )
@@ -30,8 +30,11 @@ from axolotl.cli.utils import (
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import patch_optimized_env
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.config import AxolotlInputConfig

+LOG = get_logger(__name__)
+

 @click.group()
 @click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -176,7 +179,7 @@ def train(
                 do_cli(config=cfg_file, **kwargs)
             except subprocess.CalledProcessError as exc:
-                logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
+                LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
                 if not sweep:
                     raise exc

@@ -333,6 +336,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
     do_vllm_serve(config, cli_args)


+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@add_options_from_dataclass(QuantizeCliArgs)
+@filter_none_kwargs
+def quantize(config: str, **cli_args: QuantizeCliArgs):
+    from axolotl.cli.quantize import do_quantize
+
+    do_quantize(config, cli_args)
+
+
 @cli.command()
 @click.argument("model", type=click.Path(exists=True, path_type=str))
 @click.argument("output", type=click.Path(exists=False, path_type=str))

View File

@@ -1,20 +1,18 @@
 """CLI to merge a trained LoRA into a base model."""

-import logging
 from pathlib import Path
 from typing import Union

 import fire
-import transformers
 from dotenv import load_dotenv

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 def do_merge_lora(*, cfg: DictDefault) -> None:

@@ -68,12 +66,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
     Raises:
         ValueError: If target directory for LoRA merged model does not exist.
     """
-    # pylint: disable=duplicate-code
-    parser = transformers.HfArgumentParser(TrainerCliArgs)
-    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
-        return_remaining_strings=True
-    )
-    parsed_cli_args.merge_lora = True

     parsed_cfg = load_cfg(
         config,
@@ -81,7 +73,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         load_in_8bit=False,
         load_in_4bit=False,
         flash_attention=False,
-        sequence_parallel_degree=None,
+        context_parallel_degree=None,
         deepspeed=None,
         fsdp=None,
         fsdp_config=None,

View File

@@ -1,7 +1,6 @@
 """CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""

 import json
-import logging
 import os
 import shutil
 from pathlib import Path

@@ -11,7 +10,6 @@ import fire
 import torch
 import torch.distributed.checkpoint as dist_cp
 import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
-import transformers
 from accelerate.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,

@@ -24,11 +22,11 @@ from huggingface_hub import split_torch_state_dict_into_shards
 from safetensors.torch import save_file as safe_save_file
 from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

-from axolotl.cli.args import TrainerCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):

@@ -197,11 +195,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     """
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
-    parser = transformers.HfArgumentParser(TrainerCliArgs)
-    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
-        return_remaining_strings=True
-    )
-    parsed_cli_args.merge_lora = True
     parsed_cfg = load_cfg(config, **kwargs)

     fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -1,6 +1,5 @@
 """CLI to run preprocessing of a dataset."""

-import logging
 import warnings
 from pathlib import Path
 from typing import Union

@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
 from axolotl.utils.trainer import disable_datasets_caching

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

View File

@@ -0,0 +1,90 @@
"""
CLI to post-training quantize a model using torchao
"""
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = get_logger(__name__)
def do_quantize(
config: Union[Path, str],
cli_args: dict,
):
"""
Quantizes a model's model's weights
Args:
config (Union[Path, str]): The path to the config file
cli_args (dict): Additional command-line arguments
"""
print_axolotl_text_art()
cfg = load_cfg(config)
if cfg.qat and cfg.quantization:
raise ValueError(
"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
)
if cfg.qat:
quantize_cfg = cfg.qat
elif cfg.quantization:
quantize_cfg = cfg.quantization
else:
raise ValueError(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
quantize_embedding = (
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
f"\tweight_dtype: {weight_dtype}\n"
f"\tactivation_dtype: {activation_dtype}\n"
f"\tgroup_size: {group_size}\n"
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
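The option resolution in `do_quantize` above lets a CLI flag override the config value, looking dtype names up by enum member. A minimal sketch of that precedence, where `IntDType` and `resolve_dtype` are illustrative stand-ins rather than axolotl APIs:

```python
from enum import Enum

class IntDType(Enum):
    # Hypothetical stand-in for axolotl's TorchIntDType members.
    int4 = "int4"
    int8 = "int8"

def resolve_dtype(cli_args: dict, cfg_dtype, key: str):
    # Mirrors the walrus-operator pattern above: a CLI string is
    # resolved via enum name lookup, else fall back to the config value.
    if name := cli_args.get(key):
        return IntDType[name]
    return cfg_dtype

# A CLI flag overrides the config value...
assert resolve_dtype({"weight_dtype": "int4"}, IntDType.int8, "weight_dtype") is IntDType.int4
# ...and an absent flag falls back to the config.
assert resolve_dtype({}, IntDType.int8, "weight_dtype") is IntDType.int8
```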


@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import gc
-import logging
import os
from pathlib import Path
from typing import Union
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
-LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""


@@ -4,7 +4,6 @@ import concurrent.futures
import dataclasses
import hashlib
import json
-import logging
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -23,8 +22,9 @@ from transformers import (
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
-LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):


@@ -2,14 +2,27 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@dataclass
class AxolotlScriptArguments(ScriptArguments):
"""
Additional arguments for the VLLM server
"""
reasoning_parser: str = field(default="", kw_only=True)
enable_reasoning: bool | None = field(default=None, kw_only=True)
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
@@ -24,6 +37,7 @@ def do_vllm_serve(
Returns:
process_id: the process id of the started VLLM server
"""
patch_vllm_worker()
cfg = load_cfg(config)
model = cfg.base_model
@@ -43,9 +57,16 @@ def do_vllm_serve(
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
reasoning_parser = (
cli_args.get("reasoning_parser") or cfg.vllm.reasoning_parser or ""
)
enable_reasoning = (
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
-vllm_script_args = ScriptArguments(  # pylint: disable=unexpected-keyword-arg
-model,
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
@@ -53,5 +74,67 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker
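The patched `llm_worker` loop above distinguishes `call` (result sent back over the pipe) from `fire_and_forget` (result discarded) and `shutdown`. A minimal, process-free sketch of that dispatch; `handle_command` and the `methods` table are illustrative names only:

```python
def handle_command(command: dict, methods: dict):
    # One iteration of the worker's dispatch loop, minus the LLM and the pipe.
    if command["type"] in ("call", "fire_and_forget"):
        method = methods[command["method"]]
        result = method(*command.get("args", ()), **command.get("kwargs", {}))
        # Only "call" sends the result back; "fire_and_forget" drops it.
        return result if command["type"] == "call" else None
    if command["type"] == "shutdown":
        return "shutdown"
    return None

methods = {"add": lambda a, b: a + b}
assert handle_command({"type": "call", "method": "add", "args": (1, 2)}, methods) == 3
assert handle_command({"type": "fire_and_forget", "method": "add", "args": (1, 2)}, methods) is None
assert handle_command({"type": "shutdown"}, methods) == "shutdown"
```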


@@ -1,6 +1,5 @@
"""Dataset loading utilities."""
-import logging
import math
import random
from dataclasses import dataclass
@@ -14,10 +13,11 @@ from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
-LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
@dataclass


@@ -0,0 +1,6 @@
"""Trainer builder classes"""
from .causal import HFCausalTrainerBuilder
from .rl import HFRLTrainerBuilder
__all__ = ["HFCausalTrainerBuilder", "HFRLTrainerBuilder"]


@@ -0,0 +1,503 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for trainer builder"""
import abc
import importlib
import logging
import sys
from abc import abstractmethod
from contextlib import suppress
from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
self._train_dataset = None
self._eval_dataset = None
self._model_ref = None
self._peft_config = None
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@property
def peft_config(self):
return self._peft_config
@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config
@abstractmethod
def build(self, total_num_steps):
pass
def get_callbacks(self) -> list[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
callbacks.append(GPUStatsCallback(cfg=self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_ratio = self.cfg.warmup_ratio
elif total_num_steps:
warmup_steps = min(int(0.03 * total_num_steps), 100)
else:
warmup_ratio = 0.03
if warmup_steps == 1:
warmup_steps = 2
if self.cfg.logging_steps is not None:
training_args_kwargs["logging_steps"] = self.cfg.logging_steps
else:
training_args_kwargs["logging_steps"] = (
500 # transformers defaults to 500
if not total_num_steps
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwargs: dict):
def _configure_custom_optimizer(
training_args_kwargs: dict, trainer_kwargs: dict
):
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_args_kwargs["learning_rate"],
"weight_decay": training_args_kwargs["weight_decay"],
}
# Adam-specific kwargs
adam_kwargs: dict = {}
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
"adam_beta2"
):
adam_kwargs["betas"] = (
training_args_kwargs.get("adam_beta1"),
training_args_kwargs.get("adam_beta2"),
)
if training_args_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
beta3 = training_args_kwargs.get("adam_beta3", 0.9999)
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
# Note: This is not used in training_args_kwargs, but in trainer_kwargs
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
_configure_custom_optimizer(training_args_kwargs, trainer_kwargs)
else:
# Use transformers' optimizer
training_args_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_args_kwargs["optim_args"] = optim_args
if (
self.cfg.optimizer == "adamw_anyprecision"
and Path(self.cfg.torchdistx_path).exists()
):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
def _configure_hub_parameters(self, training_args_kwargs: dict):
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
# eval_strategy and eval_steps
if not self.eval_dataset or self.cfg.val_set_size == 0:
# do not eval if no eval_dataset or val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
def _configure_reporting(self, training_args_kwargs: dict):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_args_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_args_kwargs["run_name"] = None
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
def _set_base_training_args(
self, total_num_steps
) -> tuple[dict[str, Any], dict[str, Any]]:
training_args_kwargs: dict[str, Any] = {}
trainer_kwargs: dict[str, Any] = {}
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
self._configure_precision_settings(training_args_kwargs)
self._configure_save_and_eval_strategy(training_args_kwargs)
self._configure_gradient_checkpointing(training_args_kwargs)
# set arg into trainer_args_kwargs with same name if value not None
for arg in [
# optim/scheduler
"adam_beta1",
"adam_beta2",
"adam_beta3",
"adam_epsilon",
"adam_epsilon2",
"cosine_min_lr_ratio",
"cosine_constant_lr_ratio",
"optim_target_modules",
# trainer
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"embedding_lr",
"embedding_lr_scale",
"lr_groups",
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs)
return training_args_kwargs, trainer_kwargs
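The warmup/logging defaults in `_configure_warmup_and_logging` above reduce to: explicit steps win, then ratio, then 3% of total steps capped at 100 (a lone warmup step is bumped to 2), with logging every 0.5% of steps clamped to [1, 10]. A pure-function restatement of that arithmetic for illustration (`default_warmup_and_logging` is a hypothetical name, not part of axolotl):

```python
def default_warmup_and_logging(total_num_steps, warmup_steps=None, warmup_ratio=None, logging_steps=None):
    # Re-derives the defaults chosen by _configure_warmup_and_logging.
    out = {"warmup_ratio": 0.0, "warmup_steps": 0}
    if warmup_steps:
        out["warmup_steps"] = warmup_steps
    elif warmup_ratio:
        if total_num_steps:
            out["warmup_steps"] = max(int(warmup_ratio * total_num_steps), 0)
        else:
            out["warmup_ratio"] = warmup_ratio
    elif total_num_steps:
        out["warmup_steps"] = min(int(0.03 * total_num_steps), 100)
    else:
        out["warmup_ratio"] = 0.03
    if out["warmup_steps"] == 1:
        out["warmup_steps"] = 2  # a single warmup step is bumped to two
    out["logging_steps"] = (
        logging_steps
        if logging_steps is not None
        else 500  # transformers' default when total steps are unknown
        if not total_num_steps
        else max(min(int(0.005 * total_num_steps), 10), 1)
    )
    return out

# 1000 total steps: 3% warmup (30 steps), 0.5% logging (every 5 steps).
assert default_warmup_and_logging(1000)["warmup_steps"] == 30
assert default_warmup_and_logging(1000)["logging_steps"] == 5
# 50 total steps: int(1.5) = 1 warmup step, bumped to 2; logging clamped up to 1.
assert default_warmup_and_logging(50)["warmup_steps"] == 2
assert default_warmup_and_logging(50)["logging_steps"] == 1
```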


@@ -0,0 +1,489 @@
"""Builder for causal trainers"""
import inspect
import math
import os
from pathlib import Path
from typing import Type, Union
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlMambaTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(EvalFirstStepCallback())
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not self.cfg.flash_attention
)
            training_arguments_kwargs["eval_sample_packing"] = bool(
                self.cfg.eval_sample_packing
            )
        if self.cfg.sample_packing_bin_size is not None:
            training_arguments_kwargs["sample_packing_bin_size"] = (
                self.cfg.sample_packing_bin_size
            )
        if self.cfg.sample_packing_group_size is not None:
            training_arguments_kwargs["sample_packing_group_size"] = (
                self.cfg.sample_packing_group_size
            )
        if self.cfg.sample_packing_eff_est:
            training_arguments_kwargs["sample_packing_efficiency"] = (
                self.cfg.sample_packing_eff_est
            )

        if self.cfg.relora_steps:
            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
            training_arguments_kwargs["relora_warmup_steps"] = (
                self.cfg.relora_warmup_steps
            )
            if self.cfg.relora_anneal_steps:
                training_arguments_kwargs["relora_anneal_steps"] = (
                    self.cfg.relora_anneal_steps
                )
            if self.cfg.relora_prune_ratio:
                training_arguments_kwargs["relora_prune_ratio"] = (
                    self.cfg.relora_prune_ratio
                )

        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
            training_arguments_kwargs["lisa_step_interval"] = (
                self.cfg.lisa_step_interval
            )
            training_arguments_kwargs["lisa_layers_attribute"] = (
                self.cfg.lisa_layers_attribute
            )

        training_arguments_kwargs = self.hook_pre_create_training_args(
            training_arguments_kwargs
        )
        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

        if self.cfg.chat_template:
            training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
                cfg=self.cfg,
                tokenizer=self.tokenizer,
            )

        if self.cfg.neftune_noise_alpha is not None:
            training_arguments_kwargs["neftune_noise_alpha"] = (
                self.cfg.neftune_noise_alpha
            )

        if self.cfg.accelerator_config:
            training_arguments_kwargs["accelerator_config"] = (
                self.cfg.accelerator_config
            )

        if self.cfg.image_size:
            training_arguments_kwargs["image_size"] = self.cfg.image_size
        if self.cfg.image_resize_algorithm:
            training_arguments_kwargs["image_resize_algorithm"] = (
                self.cfg.image_resize_algorithm
            )

        if self.cfg.kd_ce_alpha is not None:
            training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
        if self.cfg.kd_alpha is not None:
            training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
        if self.cfg.kd_temperature is not None:
            training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
        if self.cfg.kd_zscore_base_temp is not None:
            training_arguments_kwargs["kd_zscore_base_temp"] = (
                self.cfg.kd_zscore_base_temp
            )
        if self.cfg.kd_top_k_before_softmax is not None:
            training_arguments_kwargs["kd_top_k_before_softmax"] = (
                self.cfg.kd_top_k_before_softmax
            )

        if self.cfg.reward_model:
            training_args_cls = AxolotlRewardConfig
        elif self.cfg.process_reward_model:
            training_args_cls = AxolotlPRMConfig
        else:
            training_args_cls = AxolotlTrainingArguments
        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
            **training_arguments_kwargs,
        )
        training_args = self.hook_post_create_training_args(training_args)

        # unset run_name so wandb sets up experiment names
        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
                None
            )

        data_collator_kwargs = {
            "padding": True,  # True/"longest" is the default
        }
        multiple = 64
        if self.cfg.pad_to_sequence_len:
            data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
                self.cfg.sequence_len / multiple
            )
        else:
            # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
            # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
            data_collator_kwargs["pad_to_multiple_of"] = multiple

        trainer_cls = self._get_trainer_cls()
        trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
            trainer_kwargs, trainer_cls
        )
        if eval_data_collator := self.build_collator(
            training_args, is_eval=True, **data_collator_kwargs
        ):
            if not (self.cfg.reward_model or self.cfg.process_reward_model):
                trainer_kwargs["eval_data_collator"] = eval_data_collator
        if not (self.cfg.reward_model or self.cfg.process_reward_model):
            trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
                self.tokenizer,
                return_tensors="pt",
                **data_collator_kwargs,
            )
        sig = inspect.signature(trainer_cls)
        if "processing_class" in sig.parameters:
            trainer_kwargs["processing_class"] = self.tokenizer
        elif "tokenizer" in sig.parameters:
            trainer_kwargs["tokenizer"] = self.tokenizer
        if (
            trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
            and self.cfg.datasets is not None
        ):
            trainer_kwargs["dataset_tags"] = [
                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
            ]
        trainer = trainer_cls(
            model=self.model,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            args=training_args,
            data_collator=self.build_collator(training_args, **data_collator_kwargs),
            callbacks=self.get_callbacks(),
            **trainer_kwargs,
        )
        trainer = self.hook_post_create_trainer(trainer)
        for callback in self.get_post_trainer_create_callbacks(trainer):
            trainer.add_callback(callback)

        if self.cfg.deepspeed and self.cfg.sample_packing:
            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
                "train_micro_batch_size_per_gpu"
            ] = self.cfg.micro_batch_size

        return trainer

    def build_collator(
        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
    ):
        if training_args.pretraining:
            if (
                self.cfg.pretraining_sample_concatenation is False
                or self.cfg.micro_batch_size > 1
            ):
                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
            return None

        if self.cfg.model_config_type == "mamba":
            return MambaDataCollator(tokenizer=self.tokenizer)

        use_batch_sampler_collator = False
        if is_eval is False and training_args.sample_packing:
            use_batch_sampler_collator = True
        if is_eval and training_args.eval_sample_packing:
            use_batch_sampler_collator = True

        collator: Type[
            Union[
                V2BatchSamplerDataCollatorForSeq2Seq,
                BatchSamplerDataCollatorForSeq2Seq,
                DataCollatorForSeq2Seq,
                DataCollatorWithFlattening,
                RewardDataCollatorWithPadding,
            ]
        ]
        collator_args = [self.tokenizer]
        if self.cfg.reward_model:
            collator = RewardDataCollatorWithPadding
        elif use_batch_sampler_collator:
            # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
            # supported multipack models, or non-flash-attention llama
            if (
                self.cfg.flex_attention
                or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
                or (
                    self.cfg.model_config_type in ["llama"]
                    and self.cfg.flash_attention is not True
                )
            ):
                collator = V2BatchSamplerDataCollatorForSeq2Seq
            else:
                collator = BatchSamplerDataCollatorForSeq2Seq
        else:
            if self.cfg.processor_type and self.processor:
                collator = MultiModalChatDataCollator
                kwargs["processing_strategy"] = get_processing_strategy(
                    self.processor,
                    training_args.chat_template,
                    self.cfg.chat_template,
                    image_size=training_args.image_size,
                    image_resize_algorithm=training_args.image_resize_algorithm,
                )
            elif self.cfg.batch_flattening:
                collator = DataCollatorWithFlattening
                collator_args.pop(0)
                kwargs.pop("pad_to_multiple_of", None)
                kwargs.pop("padding", None)
            elif self.cfg.kd_trainer:
                from axolotl.integrations.kd.collator import (
                    DataCollatorForKD,
                    KDBatchSamplerDataCollatorForSeq2Seq,
                )

                if self.cfg.sample_packing:
                    collator = KDBatchSamplerDataCollatorForSeq2Seq
                else:
                    collator = DataCollatorForKD
            else:
                collator = DataCollatorForSeq2Seq

        kwargs["return_tensors"] = "pt"

        return collator(
            *collator_args,
            **kwargs,
        )
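The `pad_to_multiple_of` choice in `build()` above is plain rounding arithmetic; a minimal standalone sketch (a hypothetical helper that mirrors the logic, not an axolotl API):

```python
import math


def pad_to_multiple_of(sequence_len: int, pad_to_sequence_len: bool, multiple: int = 64) -> int:
    """Mirror of the collator padding logic above: when padding to the full
    sequence length, round sequence_len up to the next multiple of 64;
    otherwise just pad each batch to the next multiple of 64."""
    if pad_to_sequence_len:
        return multiple * math.ceil(sequence_len / multiple)
    return multiple
```

For example, a `sequence_len` of 1000 with `pad_to_sequence_len` enabled pads out to 1024, while 4096 is already a multiple of 64 and is left as-is.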


@@ -0,0 +1,246 @@
"""Builder for RLHF trainers"""

import inspect
from pathlib import Path

from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlORPOTrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
    AxolotlCPOConfig,
    AxolotlKTOConfig,
    AxolotlORPOConfig,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType

LOG = get_logger(__name__)


class HFRLTrainerBuilder(TrainerBuilderBase):
    """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""

    def get_callbacks(self):
        callbacks = super().get_callbacks()
        return callbacks

    def get_post_trainer_create_callbacks(self, trainer):
        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
        return callbacks

    def _get_trainer_cls(self, trainer_kwargs: dict):
        """Returns trainer_cls and trainer_cls_args"""
        if self.cfg.plugins:
            plugin_manager = PluginManager.get_instance()
            trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
            trainer_cls_args = []  # type: ignore
            if trainer_cls is not None:
                return trainer_cls, trainer_cls_args

        trainer_cls = None
        trainer_cls_args = [self.model]
        if self.cfg.rl is RLType.GRPO:
            trainer_cls = GRPOStrategy.get_trainer_class(
                context_parallel=self.cfg.context_parallel_degree > 1
            )
            trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
            trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
            trainer_cls = DPOStrategy.get_trainer_class()
            trainer_cls_args.append(self.model_ref)
        elif self.cfg.rl is RLType.ORPO:
            trainer_cls = AxolotlORPOTrainer
        elif self.cfg.rl is RLType.KTO:
            trainer_cls = AxolotlKTOTrainer
        elif self.cfg.rl is RLType.SIMPO:
            trainer_cls = AxolotlCPOTrainer
        else:
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")

        return trainer_cls, trainer_cls_args

    def _build_training_arguments(self, total_num_steps):
        """Returns training_args and trainer_kwargs"""
        training_args_kwargs, trainer_kwargs = self._set_base_training_args(
            total_num_steps=total_num_steps
        )

        if self.cfg.remove_unused_columns is not None:
            training_args_kwargs["remove_unused_columns"] = (
                self.cfg.remove_unused_columns
            )
        else:
            training_args_kwargs["remove_unused_columns"] = False

        # only rlhf
        if self.cfg.dataset_processes:
            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
        if self.cfg.trl and self.cfg.trl.beta is not None:
            training_args_kwargs["beta"] = self.cfg.trl.beta
        elif self.cfg.rl_beta is not None:
            training_args_kwargs["beta"] = self.cfg.rl_beta
        elif self.cfg.orpo_alpha is not None:
            # trl reuses the beta parameter to carry ORPO's alpha
            training_args_kwargs["beta"] = self.cfg.orpo_alpha
        if self.cfg.rpo_alpha is not None:
            training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

        if self.cfg.use_wandb:
            training_args_kwargs["run_name"] = self.cfg.wandb_name

        training_args_cls = None
        blocklist_args_kwargs = []
        if self.cfg.rl is RLType.SIMPO:
            training_args_cls = AxolotlCPOConfig
            training_args_kwargs["loss_type"] = "simpo"
            training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
            if self.cfg.cpo_alpha is not None:
                training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
        elif self.cfg.rl is RLType.ORPO:
            training_args_cls = AxolotlORPOConfig
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
        elif self.cfg.rl is RLType.KTO:
            training_args_cls = AxolotlKTOConfig
            training_args_kwargs["desirable_weight"] = (
                self.cfg.kto_desirable_weight or 1.0
            )
            training_args_kwargs["undesirable_weight"] = (
                self.cfg.kto_undesirable_weight or 1.0
            )
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
        elif self.cfg.rl is RLType.GRPO:
            training_args_cls = GRPOStrategy.get_training_args_class()
            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
            blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
            training_args_cls = AxolotlDPOConfig
            if self.cfg.rl is RLType.IPO:
                training_args_kwargs["loss_type"] = "ipo"
            # Not compatible with IPO
            if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
                training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
            training_args_kwargs["max_completion_length"] = None
            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
            if self.cfg.dpo_use_weighting is not None:
                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
            if self.cfg.dpo_use_logits_to_keep is not None:
                training_args_kwargs["use_logits_to_keep"] = (
                    self.cfg.dpo_use_logits_to_keep
                )
        else:
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")

        for blocklist_key in blocklist_args_kwargs:
            if blocklist_key in training_args_kwargs:
                del training_args_kwargs[blocklist_key]

        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
            logging_first_step=True,
            **training_args_kwargs,
        )

        # unset run_name so wandb sets up experiment names
        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
                None
            )

        return training_args, trainer_kwargs

    def build(self, total_num_steps):
        training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)

        if self.eval_dataset:
            trainer_kwargs["eval_dataset"] = self.eval_dataset
        if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
            trainer_kwargs["peft_config"] = self.peft_config
        if self.cfg.precompute_ref_log_probs is not None:
            trainer_kwargs["precompute_ref_log_probs"] = (
                self.cfg.precompute_ref_log_probs
            )

        trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
        sig = inspect.signature(trainer_cls)
        if "tokenizer" in sig.parameters:
            trainer_kwargs["tokenizer"] = self.tokenizer
        else:
            trainer_kwargs["processing_class"] = self.tokenizer

        if self.cfg.datasets is not None and (
            trainer_cls is DPOStrategy.get_trainer_class()
        ):
            trainer_kwargs["dataset_tags"] = [
                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
            ]

        trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
            trainer_kwargs, trainer_cls
        )
        trainer = trainer_cls(
            *trainer_cls_args,
            args=training_args,
            train_dataset=self.train_dataset,
            callbacks=self.get_callbacks(),
            **trainer_kwargs,
        )

        if self.cfg.fsdp:
            ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
            if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
                ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)

        trainer = self.hook_post_create_trainer(trainer)
        for callback in self.get_post_trainer_create_callbacks(trainer):
            trainer.add_callback(callback)

        return trainer


class HFPPOTrainerBuilder(TrainerBuilderBase):
    """HF Factory class for PPO Trainer"""

    def get_callbacks(self):
        callbacks = super().get_callbacks()
        return callbacks

    def get_post_trainer_create_callbacks(self, trainer):
        callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
        return callbacks

    def build(self, total_num_steps):
        # TODO: build PPOConfig
        raise NotImplementedError("PPO trainer builder is not implemented yet.")
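The beta precedence used in `_build_training_arguments` above (an explicit `trl.beta`, then `rl_beta`, then `orpo_alpha`) can be sketched as a small pure function; this is a hypothetical illustration, not a function in the codebase:

```python
def resolve_beta(trl_beta=None, rl_beta=None, orpo_alpha=None):
    """First non-None value wins: trl.beta overrides rl_beta, which overrides
    orpo_alpha (trl reuses its beta field to carry ORPO's alpha)."""
    for value in (trl_beta, rl_beta, orpo_alpha):
        if value is not None:
            return value
    return None
```

So a config that sets both `rl_beta` and `orpo_alpha` trains with `rl_beta`, and `orpo_alpha` only takes effect when neither of the other two is set.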


@@ -156,7 +156,6 @@ class Messages(BaseModel):
                     len(input_ids) : len(input_ids) + len(pending_input_ids)
                 ]
                 if new_pending_inputs != pending_input_ids:
-                    # logging.warning("tokenization mismatch from concatenation.")
                     pending_input_ids = new_pending_inputs
             input_ids.extend(pending_input_ids)
             if pending_weight:

File diff suppressed because it is too large


@@ -5,7 +5,7 @@
 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
-from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
+from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
 from .relora import ReLoRATrainer
 from .trl import (


@@ -4,15 +4,16 @@
 from __future__ import annotations

-import logging
 import os
 from collections import defaultdict
-from functools import wraps
-from typing import Literal
+from functools import partial, wraps
+from typing import Any, Callable, Literal, Optional
+
+from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager

 import datasets
 import torch
 from datasets import Dataset
+from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -34,9 +35,10 @@ from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
 )
+from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
@@ -65,6 +67,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

+        # SDPA device mesh init
+        import torch.distributed as dist
+
+        world_size = dist.get_world_size()
+        mesh_shape = (
+            world_size // 2,
+            2,
+        )
+        self.world_mesh = dist.DeviceMesh(
+            "cuda",
+            torch.tensor(list(range(world_size))).reshape(mesh_shape),
+            mesh_dim_names=("dp", "cp"),
+        )
+
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        num_items_in_batch=None,
+    ) -> torch.Tensor:
+        ctx_manager = get_context_parallel_manager(
+            world_mesh=self.world_mesh,
+            model=model,
+        )
+        to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
+        with ctx_manager(list(to_shard.values())):
+            return super().training_step(model, inputs, num_items_in_batch)
+
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
@@ -113,7 +141,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
             drop_last=True,
         )

-    def _get_train_sampler(self) -> Sampler | None:
+    def _get_train_sampler(
+        self, train_dataset: Optional[Dataset] = None
+    ) -> Optional[Sampler]:
         """
         Helper method to get the sampler for training. Handles cases for sample packing
         and curriculum sampling (sequential).
@@ -137,7 +167,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         if use_sample_packing:
             return self._create_multipack_sampler(
                 base_sampler=base_sampler,
-                dataset=self.train_dataset,
+                dataset=train_dataset,
             )

         return base_sampler
@@ -150,8 +180,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         If the dataset is non-empty, a sampler is returned, the type of which
         depends on the passed training args.
         """
-        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
         # Multipacking enabled if training is enabled and eval is not explicitly disabled
         use_multipack = (
             self.args.sample_packing and self.args.eval_sample_packing is not False
@@ -172,125 +200,91 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         return base_sampler

-    def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
-        """Create common dataloader parameters for train or eval."""
-        batch_size = custom_batch_size or (
-            self.args.eval_batch_size if is_eval else self._train_batch_size
-        )
-
-        params = {
-            "batch_size": batch_size,
-            "collate_fn": self.data_collator,
-            "num_workers": self.args.dataloader_num_workers,
-            "pin_memory": self.args.dataloader_pin_memory,
-        }
-
-        # Add persistent workers only for training
-        if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
-            params["persistent_workers"] = self.args.dataloader_persistent_workers
-
-        # Add prefetch factor if specified
-        if self.args.dataloader_prefetch_factor:
-            params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-        return params
-
-    def _prepare_dataloader(
-        self, dataset, sampler, is_eval=False, custom_batch_size=None
-    ):
-        """Prepare a dataloader with the given dataset and sampler."""
-        # Get base parameters
-        dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
-
-        # Add sampler configuration
-        if not isinstance(dataset, torch.utils.data.IterableDataset):
-            if isinstance(sampler, BatchSampler):
-                # batch_size and batch_sampler are mutually exclusive
-                dataloader_params["batch_sampler"] = sampler
-                del dataloader_params["batch_size"]
-            else:
-                dataloader_params["sampler"] = sampler
-                dataloader_params["drop_last"] = self.args.dataloader_drop_last
-
-        if not is_eval:
-            dataloader_params["worker_init_fn"] = seed_worker
-
-        # Create the dataloader
-        dataloader = DataLoader(dataset, **dataloader_params)
-
-        if self.args.sample_packing and (
-            (not is_eval and not self.args.pretraining)
-            or (is_eval and self.args.eval_sample_packing is not False)
-        ):
-            self.accelerator.even_batches = False
-
-        return self.accelerator.prepare_data_loader(dataloader)
-
-    def get_train_dataloader(self) -> DataLoader:
-        """Get dataloader for training"""
-        train_dataset = self.train_dataset
-        data_collator = self.data_collator  # type: ignore
-
-        # Handle dataset preprocessing
-        if isinstance(train_dataset, datasets.Dataset):
-            if self.args.sample_packing and not self.args.pretraining:
-                train_dataset = train_dataset.remove_columns(["length"])
-            if not self.args.sample_packing or self.args.pretraining:
-                train_dataset = self._remove_unused_columns(
-                    train_dataset, description="training"
-                )
-        else:
-            self.data_collator = self._get_collator_with_removed_columns(  # pylint: disable=attribute-defined-outside-init
-                data_collator,
-                description="training",
-            )
-
-        # Get sampler and create dataloader
-        sampler = self._get_train_sampler()
-        return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
-
-    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
-        """Get dataloader for evaluation"""
-        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
-        # Handle special case: sample packing is enabled but eval_sample_packing is False
-        if self.args.sample_packing and self.args.eval_sample_packing is False:
-            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
-                self.eval_data_collator
-            )
-            if "length" in eval_dataset.column_names:
-                eval_dataset = eval_dataset.remove_columns(["length"])
-            dataloader = super().get_eval_dataloader(eval_dataset)
-            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
-                self.train_data_collator
-            )
-            return dataloader
-
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
-            # Get appropriate data collator
-            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
-                self.eval_data_collator
-                if hasattr(self, "eval_data_collator") and self.eval_data_collator
-                else self.data_collator
-            )
-            if "length" in eval_dataset.column_names:
-                eval_dataset = eval_dataset.remove_columns(["length"])
-            # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
-            batch_size = (
-                self.args.eval_batch_size
-                if self.args.sample_packing
-                else self.args.per_device_eval_batch_size
-            )
-            sampler = self._get_eval_sampler(eval_dataset)
-            dataloader = self._prepare_dataloader(
-                eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
-            )
-            return dataloader
-
-        return super().get_eval_dataloader(eval_dataset)
+    def _get_dataloader(
+        self,
+        dataset: Dataset,
+        description: str,
+        batch_size: int,
+        sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
+        is_training: bool = False,
+        dataloader_key: Optional[str] = None,
+    ) -> DataLoader:
+        """Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
+        data_collator = self.data_collator if is_training else self.eval_data_collator
+        if dataset.column_names and "length" in dataset.column_names:
+            dataset = dataset.remove_columns(["length"])
+        if isinstance(dataset, datasets.Dataset):
+            if is_training:
+                if not self.args.sample_packing or self.args.pretraining:
+                    dataset = self._remove_unused_columns(
+                        dataset, description="training"
+                    )
+            elif (
+                not is_training
+                and self.args.sample_packing
+                and self.args.eval_sample_packing is not False
+            ):
+                batch_size = (
+                    batch_size
+                    if self.args.sample_packing
+                    else self.args.per_device_eval_batch_size
+                )
+            else:
+                dataset = self._remove_unused_columns(dataset, description=description)
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                self.data_collator, description=description
+            )
+
+        dataloader_params = {
+            "batch_size": batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(dataset, torch.utils.data.IterableDataset):
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            if sampler_fn is not None:
+                sampler = sampler_fn(dataset)
+                if isinstance(sampler, BatchSampler):
+                    # batch_size and batch_sampler are mutually exclusive
+                    dataloader_params["batch_sampler"] = sampler
+                    del dataloader_params["batch_size"]
+                    del dataloader_params["drop_last"]
+                else:
+                    dataloader_params["sampler"] = sampler
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            if is_training:
+                dataloader_params["worker_init_fn"] = partial(
+                    seed_worker,
+                    num_workers=self.args.dataloader_num_workers,
+                    rank=self.args.process_index,
+                )
+
+        if self.args.sample_packing and (
+            (is_training and not self.args.pretraining)
+            or (not is_training and self.args.eval_sample_packing is not False)
+        ):
+            self.accelerator.even_batches = False
+
+        dataloader = DataLoader(dataset, **dataloader_params)
+        # Accelerator.free_memory() will destroy the references, so
+        # we need to store the non-prepared version for eval dataloaders.
+        # fmt: off
+        if dataloader_key is not None and self.args.dataloader_persistent_workers:
+            if hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders[dataloader_key] = dataloader  # type: ignore # pylint: disable=access-member-before-definition
+            else:
+                self._eval_dataloaders = {dataloader_key: dataloader}  # pylint: disable=attribute-defined-outside-init
+        # fmt: on
+
+        return self.accelerator.prepare(dataloader)

     def _get_bench_sampler(
         self, bench_dataset: Dataset
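The (`dp`, `cp`) mesh arithmetic added in `__init__` above can be sketched standalone; this is a hypothetical helper for illustration (the actual code builds a `torch.distributed.DeviceMesh`), assuming the diff's hard-coded context-parallel degree of 2 as the default:

```python
def dp_cp_groups(world_size: int, cp_degree: int = 2) -> list[list[int]]:
    """Fold a flat list of ranks into the ("dp", "cp") mesh layout above:
    dim 0 indexes data-parallel replicas, dim 1 indexes the context-parallel
    ranks that cooperate on one sequence. Rank r lands in CP group r // cp_degree."""
    assert world_size % cp_degree == 0, "world size must divide evenly into CP groups"
    return [
        list(range(group * cp_degree, (group + 1) * cp_degree))
        for group in range(world_size // cp_degree)
    ]
```

With 8 ranks this yields four CP groups of two ranks each, matching the `(world_size // 2, 2)` reshape in the diff.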


@@ -5,65 +5,31 @@ from functools import wraps
 from typing import Any, Dict, Union

 import torch
-from peft.optimizers import create_loraplus_optimizer
 from torch import nn
-from transformers import Trainer
-from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
+from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
 from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
 )

-if is_sagemaker_mp_enabled():
-    import smdistributed.modelparallel.torch as smp
-

-class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
+class AxolotlDPOTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
+):
     """Extend the base DPOTrainer for axolotl helpers."""

     tag_names = ["axolotl", "dpo"]

     def __init__(self, *args, dataset_tags=None, **kwargs):
         super().__init__(*args, **kwargs)
         self.dataset_tags = dataset_tags
         self.optimizer = None
         self.model_accepts_loss_kwargs = False

-    def create_optimizer(self):
-        # pylint: disable=duplicate-code
-        if self.args.loraplus_lr_ratio is None:
-            return super().create_optimizer()
-
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
-            )
-
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            if loraplus_lr_ratio:
-                print("Using lora+")
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            # pylint: disable=duplicate-code
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                loraplus_lr_ratio=loraplus_lr_ratio,
-                loraplus_lr_embedding=loraplus_lr_embedding,
-                **optimizer_kwargs,
-            )
-
-        if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
-
-        return self.optimizer
-
     @wraps(DPOTrainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """


@@ -2,20 +2,20 @@
 import importlib
 import inspect
-import logging
 from typing import Any

 from trl.trainer.grpo_trainer import RewardFunc

 from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
 from axolotl.core.trainers.grpo.trainer import (
-    AxolotlGRPOSequenceParallelTrainer,
+    AxolotlGRPOContextParallelTrainer,
     AxolotlGRPOTrainer,
 )
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.trl import TRLConfig

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 class GRPOStrategy:
@@ -23,10 +23,10 @@ class GRPOStrategy:
     @classmethod
     def get_trainer_class(
-        cls, sequence_parallel: bool
-    ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
-        if sequence_parallel:
-            return AxolotlGRPOSequenceParallelTrainer
+        cls, context_parallel: bool
+    ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOContextParallelTrainer]:
+        if context_parallel:
+            return AxolotlGRPOContextParallelTrainer
         return AxolotlGRPOTrainer

     @classmethod
@@ -69,6 +69,9 @@ class GRPOStrategy:
         grpo_args_kwargs["log_completions"] = trl.log_completions
         grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print

+        if cfg.context_parallel_degree > 1:
+            grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
+
         if trl.reward_weights:
             grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -106,7 +109,9 @@ class GRPOStrategy:
         return grpo_args_kwargs

     @classmethod
-    def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
+    def set_trainer_args(
+        cls, cfg: DictDefault
+    ) -> list[Any]:  # pylint: disable=unused-argument
         trainer_args = []
         if cfg.trl and cfg.trl.reward_funcs:
             reward_funcs = []
@@ -123,6 +128,7 @@ class GRPOStrategy:
             trainer_kwargs["reward_processing_classes"] = (
                 cfg.trl.reward_processing_classes
             )
+
         return trainer_kwargs

     @classmethod
@@ -132,7 +138,7 @@ class GRPOStrategy:
     @classmethod
     def get_blocklist_args_kwargs(cls) -> list[str]:
-        return ["dataset_num_proc"]
+        return ["dataset_num_proc", "max_length"]

     @classmethod
     def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
@@ -167,4 +173,4 @@ class GRPOStrategy:
         LOG.info(
             f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
         )
-        return reward_func
+        return reward_func_fqn


@@ -12,3 +12,5 @@ from axolotl.core.training_args import AxolotlTrainingMixins
 @dataclass
 class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """Axolotl GRPO Config for GRPO training"""
+
+    context_parallel_degree: int | None = None


@@ -1,7 +1,7 @@
 """Repeat random sampler (similar to the one implemented in
 https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
-sequence parallelism functionality; i.e., duplicating data across ranks in the same
-sequence parallel group.
+context parallelism functionality; i.e., duplicating data across ranks in the same
+context parallel group.
 """

 from typing import Iterator, Sized
@@ -10,26 +10,26 @@ import torch
 from torch.utils.data import Sampler


-class SequenceParallelRepeatRandomSampler(Sampler):
-    """Sampler for GRPO training with sequence parallelism.
+class ContextParallelRepeatRandomSampler(Sampler):
+    """Sampler for GRPO training with context parallelism.

     This sampler ensures:
-    - Ranks in the same sequence parallel (SP) group receive identical data.
+    - Ranks in the same context parallel (SP) group receive identical data.
     - Each index is repeated multiple times for sampling different completions.
     - Entire batches are repeated for reuse in multiple updates.
-    - Data is properly distributed across SP groups.
+    - Data is properly distributed across CP groups.

-    In the table below, the values represent dataset indices. Each SP group has
-    `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
-    SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
+    In the table below, the values represent dataset indices. Each CP group has
+    `context_parallel_degree = 2` GPUs working together on the same data. There are 2
+    CP groups (SP0 and SP1), with `world_size = 4` total GPUs.

-    Sequence Parallel Groups
+    Context Parallel Groups
     |      SP0      |      SP1      |
     | GPU 0 | GPU 1 | GPU 2 | GPU 3 |

     global_step   step     <---> mini_repeat_count=3
-                           <----------> batch_size=2 per SP group
+                           <----------> batch_size=2 per CP group
-    grad_accum=2  ▲   ▲   0     0     [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
-                  ▼   |   0     1     [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
+    grad_accum=2  ▲   ▲   0     0     [0 0 0 1 1 1] [2 2 2 3 3 3] <- CP groups get different data
+                  ▼   |   0     1     [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each CP group GPU
                       |
                       |   1     2     [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
     num_iterations=2  ▼   1     3     [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
         rank: Rank of current process.
         batch_size: Number of samples per batch.
         repeat_count: How many times to repeat the full sampling process.
-        sequence_parallel_degree: Number of ranks in a sequence parallel group.
+        context_parallel_degree: Number of ranks in a context parallel group.
         shuffle: Whether to shuffle the dataset.
         seed: Random seed for shuffling.
         drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
         rank: int,
         batch_size: int = 1,
         repeat_count: int = 1,
-        sequence_parallel_degree: int = 1,
+        context_parallel_degree: int = 1,
         shuffle: bool = True,
         seed: int = 0,
         drop_last: bool = False,
@@ -76,16 +76,16 @@ class SequenceParallelRepeatRandomSampler(Sampler):
         self.world_size = world_size
         self.rank = rank

-        # Sequence parallelism parameters
-        self.sequence_parallel_degree = sequence_parallel_degree
-        self.num_sp_groups = world_size // sequence_parallel_degree
-        self.sp_group_id = rank // sequence_parallel_degree
+        # Context parallelism parameters
+        self.context_parallel_degree = context_parallel_degree
+        self.num_sp_groups = world_size // context_parallel_degree
+        self.sp_group_id = rank // context_parallel_degree

         # Adjust dataset size for distributed sampling
         self.num_samples = len(self.dataset)
         self.total_size = self.num_samples

-        # Calculate effective number of samples per SP group
+        # Calculate effective number of samples per CP group
         if (
             self.drop_last
             and self.total_size % (self.num_sp_groups * self.batch_size) != 0
@@ -125,8 +125,8 @@ class SequenceParallelRepeatRandomSampler(Sampler):
             padding = indices[: self.batch_size - len(indices) % self.batch_size]
             indices += padding

-        # Subsample based on SP group ID
-        # Each SP group gets distinct batches of data
+        # Subsample based on CP group ID
+        # Each CP group gets distinct batches of data
         batch_indices = []
         for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
             start_idx = i + self.sp_group_id * self.batch_size
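The index pattern from the docstring table can be reproduced with a small pure-Python sketch of the sampler's subsampling step (simplified: no shuffling, padding, or batch-level repeats):

```python
def cp_group_indices(indices, mini_repeat_count, batch_size, num_cp_groups, cp_group_id):
    """Each CP group takes its own batch_size-sized slice out of every
    (batch_size * num_cp_groups) chunk, then repeats each index
    mini_repeat_count times so multiple completions are sampled per prompt."""
    out = []
    for i in range(0, len(indices), batch_size * num_cp_groups):
        start = i + cp_group_id * batch_size
        for idx in indices[start : start + batch_size]:
            out.extend([idx] * mini_repeat_count)
    return out
```

With indices `[0, 1, 2, 3]`, `mini_repeat_count=3`, `batch_size=2` and two CP groups, group 0 yields `[0, 0, 0, 1, 1, 1]` and group 1 yields `[2, 2, 2, 3, 3, 3]`, matching the table above.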

View File

@@ -1,4 +1,4 @@
-"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
+"""Axolotl GRPO trainers (with and without context parallelism handling)"""

 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
@@ -41,8 +41,9 @@ from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
 from trl.trainer.utils import pad

-from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
+from axolotl.core.trainers.grpo.sampler import ContextParallelRepeatRandomSampler
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
+from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
 from axolotl.monkeypatch.ring_attn import get_ring_attn_group

 if is_peft_available():
@@ -50,14 +51,16 @@ if is_peft_available():
     from peft import PeftConfig


-class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
+class AxolotlGRPOTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
+):
     """Extend the base GRPOTrainer for axolotl helpers"""

     _tag_names = ["trl", "grpo", "axolotl"]
-class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
-    """Extend the base GRPOTrainer for sequence parallelism handling"""
+class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
+    """Extend the base GRPOTrainer for context parallelism handling"""

     def __init__(
         self,
@@ -77,6 +80,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
         ] = (None, None),
         peft_config: "PeftConfig | None" = None,
+        optimizer_cls_and_kwargs: tuple[type, dict] | None = None,
     ):
         # First call the superclass constructor with all arguments
         super().__init__(
@@ -90,13 +94,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             callbacks=callbacks,
             optimizers=optimizers,
             peft_config=peft_config,
+            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
         )
-        # Get number of SP groups (number of processes divided by SP degree)
+        # Get number of CP groups (number of processes divided by CP degree)
         num_processes = self.accelerator.num_processes
-        num_sp_groups = num_processes // self.args.sequence_parallel_degree
+        num_sp_groups = num_processes // self.args.context_parallel_degree

-        # Calculate batch size per SP group (not per process)
+        # Calculate batch size per CP group (not per process)
         sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
         possible_values = [
             n_gen
@@ -106,7 +111,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         if self.num_generations not in possible_values:
             raise ValueError(
-                f"The batch size per SP group ({num_sp_groups} x "
+                f"The batch size per CP group ({num_sp_groups} x "
                 f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
                 f"the number of generations per prompt ({self.num_generations}). Given "
                 "the current configuration, the valid values for the number of "
@@ -114,7 +119,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             )

         if self.args.eval_strategy != "no":
-            # If sequence parallelism is enabled, calculate batch size per SP group
+            # If context parallelism is enabled, calculate batch size per CP group
             sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups  # type: ignore[union-attr]
             possible_values = [
                 n_gen
@@ -124,20 +129,29 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             if self.num_generations not in possible_values:
                 raise ValueError(
-                    f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
-                    f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
+                    f"With context parallelism (degree {self.args.context_parallel_degree}), "
+                    f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
                     f"must be evenly divisible by the number of generations per prompt "
                     f"({self.num_generations}). Given the current eval batch size, "
                     f"the valid values for the number of generations are: {possible_values}."
                 )

-        # Initialize the SP group
+        self.sp_group = None
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.local_rank = 0
+        self.local_world_size = 1
+
+    def train(self, *args, **kwargs):
+        # Initialize the CP group
         self.sp_group = get_ring_attn_group()
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.local_rank = dist.get_rank(group=self.sp_group)
         self.local_world_size = dist.get_world_size(group=self.sp_group)
+
+        return super().train(*args, **kwargs)
     def _get_train_sampler(self) -> Sampler:
         effective_batch_size = (
             self.args.per_device_train_batch_size
@@ -145,16 +159,16 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             * self.args.gradient_accumulation_steps
         )

-        return SequenceParallelRepeatRandomSampler(
+        return ContextParallelRepeatRandomSampler(
             dataset=self.train_dataset,
             mini_repeat_count=self.num_generations,
             world_size=self.world_size,
             rank=self.rank,
             batch_size=effective_batch_size
             // self.num_generations
-            // self.args.sequence_parallel_degree,
+            // self.args.context_parallel_degree,
             repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
-            sequence_parallel_degree=self.args.sequence_parallel_degree,
+            context_parallel_degree=self.args.context_parallel_degree,
             shuffle=True,
             seed=self.args.seed,
             drop_last=True,
@@ -212,11 +226,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         ):
             self.accelerator.even_batches = False

-        # Return unprepared dataloader if using sequence parallelism
+        # Return unprepared dataloader if using context parallelism
         # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
         # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
         # slice each batch along the sequence dimension).
-        if self.args.sequence_parallel_degree > 1:
+        if self.args.context_parallel_degree > 1:
             return dataloader

         # Otherwise prepare with accelerator
@@ -289,21 +303,21 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
         all_prompts_text = gather_object(prompts_text)
         if self.accelerator.is_main_process:
-            if self.args.sequence_parallel_degree > 1:
-                # Calculate sequence parallel group information
+            if self.args.context_parallel_degree > 1:
+                # Calculate context parallel group information
                 world_size = self.accelerator.num_processes
-                sequence_parallel_degree = self.args.sequence_parallel_degree
-                num_sp_groups = world_size // sequence_parallel_degree
+                context_parallel_degree = self.args.context_parallel_degree
+                num_sp_groups = world_size // context_parallel_degree

-                # Since processes in the same SP group have the same prompts, we need to ensure
-                # we only take one copy of each prompt from each SP group
+                # Since processes in the same CP group have the same prompts, we need to ensure
+                # we only take one copy of each prompt from each CP group
                 ordered_set_of_prompts = []
                 for sp_group_id in range(num_sp_groups):
-                    # Get the first process from each SP group (typically the group leader)
-                    group_leader_rank = sp_group_id * sequence_parallel_degree
+                    # Get the first process from each CP group (typically the group leader)
+                    group_leader_rank = sp_group_id * context_parallel_degree

-                    # Extract prompts from this SP group, accounting for num_generations duplicates
-                    # We only need prompts from one rank in each SP group
+                    # Extract prompts from this CP group, accounting for num_generations duplicates
+                    # We only need prompts from one rank in each CP group
                     group_prompts = all_prompts_text[
                         group_leader_rank
                         * len(prompts_text) : (group_leader_rank + 1)
@@ -316,7 +330,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                 # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[
-                    :: self.num_generations * self.args.sequence_parallel_degree
+                    :: self.num_generations * self.args.context_parallel_degree
                 ]

             with profiling_context(self, "vLLM.generate"):
@@ -333,28 +347,28 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             )
         else:
             completion_ids = [None] * (
-                len(all_prompts_text) // self.args.sequence_parallel_degree
+                len(all_prompts_text) // self.args.context_parallel_degree
             )

         # Broadcast the completions from the main process to all processes
         completion_ids = broadcast_object_list(completion_ids, from_process=0)

-        # Determine the appropriate slice based on sequence parallelism
-        if self.args.sequence_parallel_degree > 1:
-            # Calculate SP group ID (which group of ranks this rank belongs to)
+        # Determine the appropriate slice based on context parallelism
+        if self.args.context_parallel_degree > 1:
+            # Calculate CP group ID (which group of ranks this rank belongs to)
             sp_group_id = self.accelerator.process_index // self.local_world_size

-            # Calculate the start index for this SP group
+            # Calculate the start index for this CP group
             sp_group_start = sp_group_id * len(prompts) * self.local_world_size

-            # All ranks in the same SP group get the same data slice
+            # All ranks in the same CP group get the same data slice
             process_slice = slice(
                 sp_group_start,
                 sp_group_start + len(prompts),
             )
             completion_ids = completion_ids[process_slice]
         else:
-            # Original behavior for non-sequence parallel case
+            # Original behavior for non-context parallel case
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
@@ -564,20 +578,20 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         advantages = advantages / (std_grouped_rewards + 1e-4)

         # Slice to keep only the local part of the data
-        if self.args.sequence_parallel_degree > 1:
-            # Calculate SP group ID (which group of ranks this rank belongs to)
+        if self.args.context_parallel_degree > 1:
+            # Calculate CP group ID (which group of ranks this rank belongs to)
             sp_group_id = self.accelerator.process_index // self.local_world_size

-            # Calculate the start index for this SP group
+            # Calculate the start index for this CP group
             sp_group_start = sp_group_id * len(prompts) * self.local_world_size

-            # All ranks in the same SP group get the same data slice
+            # All ranks in the same CP group get the same data slice
             process_slice = slice(
                 sp_group_start,
                 sp_group_start + len(prompts),
             )
         else:
-            # Original behavior for non-sequence parallel case
+            # Original behavior for non-context parallel case
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
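Two pieces of arithmetic in this trainer are easy to check in isolation: the valid `num_generations` values per CP group, and the shared completion slice per CP group. A simplified sketch (function names are illustrative; real list lengths depend on how prompts are gathered and deduplicated):

```python
def valid_num_generations(per_device_batch, world_size, cp_degree, max_n=64):
    """num_generations must evenly divide the batch size per CP group."""
    cp_group_batch = per_device_batch * (world_size // cp_degree)
    return [n for n in range(2, max_n + 1) if cp_group_batch % n == 0]


def cp_process_slice(process_index, local_world_size, num_prompts):
    """All ranks in the same CP group compute the same slice, so every rank in
    a group ends up holding identical completions for its prompts."""
    cp_group_id = process_index // local_world_size
    start = cp_group_id * num_prompts * local_world_size
    return slice(start, start + num_prompts)
```

For example, with 4 processes, CP degree 2, and a per-device batch of 4, the group batch is 8, so valid generation counts are 2, 4, and 8; ranks 0 and 1 (group 0) compute the same slice, distinct from ranks 2 and 3 (group 1).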

View File

@@ -1,18 +1,17 @@
 """Module for Axolotl trainer optimizer mixin"""

-import logging
-
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
 from transformers.trainer import Trainer
 from transformers.utils import is_sagemaker_mp_enabled

 from axolotl.integrations.base import BaseOptimizerFactory
+from axolotl.utils.logging import get_logger

 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 class OptimizerMixin(Trainer):
@@ -199,3 +198,20 @@ class OptimizerMixin(Trainer):
         )

         return self.optimizer
+
+
+class OptimizerInitMixin:
+    """
+    Mixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not
+    accept optimizer_cls_and_kwargs as kwarg in constructor.
+    """
+
+    def __init__(self, *args, **kwargs):
+        optimizer_cls_and_kwargs = kwargs.pop("optimizer_cls_and_kwargs", None)
+        super().__init__(*args, **kwargs)
+
+        if (
+            optimizer_cls_and_kwargs
+            and self.optimizer_cls_and_kwargs is None
+            and self.optimizer is None
+        ):
+            self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
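The new mixin relies on cooperative `super().__init__` and on popping the kwarg before the base constructor, which would otherwise reject it, ever sees it. The mechanism in isolation (the `BaseTrainer` stub stands in for a TRL trainer):

```python
class BaseTrainer:
    """Illustrative stand-in for a trainer whose constructor does not
    accept optimizer_cls_and_kwargs."""

    def __init__(self, model=None):
        self.model = model
        self.optimizer = None
        self.optimizer_cls_and_kwargs = None


class OptimizerInitMixin:
    def __init__(self, *args, **kwargs):
        # Pop before calling the base __init__, which doesn't know this kwarg
        optimizer_cls_and_kwargs = kwargs.pop("optimizer_cls_and_kwargs", None)
        super().__init__(*args, **kwargs)
        # Only stash it if the base constructor didn't already set one up
        if (
            optimizer_cls_and_kwargs
            and self.optimizer_cls_and_kwargs is None
            and self.optimizer is None
        ):
            self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs


class SketchTrainer(OptimizerInitMixin, BaseTrainer):
    pass
```

Placing `OptimizerInitMixin` before the base trainer in the MRO is what lets the pop happen first, which is why the diff above adds it to each trainer's bases in that order.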

View File

@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
 TODO: Remove when upstream added PR to release
 """

-import logging
 import os
 import random
@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
 from transformers.trainer_pt_utils import set_rng_state_for_device
 from transformers.training_args import ParallelMode

-LOG = logging.getLogger(__name__)
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)


 class RngLoaderMixin(Trainer):

View File

@@ -1,12 +1,11 @@
 """Module for Axolotl trainer scheduler mixin"""

-import logging
-
 import torch
 from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 from transformers.trainer import Trainer

 from axolotl.integrations.base import PluginManager
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schedulers import (
     RexLR,
     get_cosine_schedule_with_min_lr,
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_warmup_decay_constant,
 )

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


 class SchedulerMixin(Trainer):
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
             self.lr_scheduler = RexLR(
                 optimizer=optimizer,
                 max_lr=self.args.learning_rate,
-                min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
+                min_lr=0 if not use_cosine_min_lr else (
+                    self.args.learning_rate * self.args.cosine_min_lr_ratio),
                 total_steps=num_training_steps,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
             )
         elif use_cosine_quadratic:
             if use_cosine_min_lr:
-                LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
+                LOG.warning(
+                    "Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

             self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                 optimizer,
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
             return super().create_scheduler(num_training_steps, optimizer=optimizer)
         else:
             if use_cosine_quadratic:
-                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
+                LOG.warning(
+                    "axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
             if use_cosine_min_lr:
-                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
+                LOG.warning(
+                    "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

         return self.lr_scheduler  # type: ignore

View File

@@ -1,7 +1,5 @@
 """Module for TRL PPO trainer"""

-from typing import Literal, Union
-
 import torch
 from tqdm import tqdm
 from trl import (
@@ -14,6 +12,7 @@ from trl import (
 )

 from axolotl.core.trainers.mixins import RngLoaderMixin
+from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
 from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -75,87 +74,19 @@ class TRLPPOTrainer(PPOTrainer):
     )


-class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
+class AxolotlORPOTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
+):
     """
     Extend the base ORPOTrainer for axolotl helpers
     """

     tag_names = ["axolotl", "orpo"]
-    def get_batch_loss_metrics(
-        self,
-        model,
-        batch: dict[str, Union[list, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
-        # TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
-        metrics = {}
-        forward_output = self.concatenated_forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
-
-        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
-            self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
-        )
-        # full ORPO loss
-        loss = policy_nll_loss - losses.mean()
-
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
-            chosen_rewards
-        ).mean()
-        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
-            rejected_rewards
-        ).mean()
-        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
-            reward_accuracies
-        ).mean()
-        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
-            chosen_rewards - rejected_rewards
-        ).mean()
-        metrics[f"{prefix}logps/rejected"] = (
-            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
-        )
-        metrics[f"{prefix}logps/chosen"] = (
-            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
-        )
-        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
-            policy_rejected_logits.detach().mean()
-        ).mean()
-        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
-            policy_chosen_logits.detach().mean()
-        ).mean()
-        metrics[f"{prefix}nll_loss"] = (
-            self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
-        )
-        metrics[f"{prefix}log_odds_ratio"] = (
-            self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
-        )
-        metrics[f"{prefix}log_odds_chosen"] = (
-            self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
-        )
-        for k, v in metrics.items():
-            metrics[k] = v.item()
-        if self.aux_loss_enabled:
-            loss += self.aux_loss_coef * aux_loss
-
-        return loss, metrics
-
-
-class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
+class AxolotlKTOTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
+):
     """
     Extend the base KTOTrainer for axolotl helpers
     """
@@ -163,89 +94,19 @@ class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
     tag_names = ["axolotl", "kto"]


-class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
+class AxolotlCPOTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
+):
     """
     Extend the base CPOTrainer for axolotl helpers
     """

     tag_names = ["axolotl", "cpo"]
-    def get_batch_loss_metrics(
-        self,
-        model,
-        batch: dict[str, Union[list, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
-        metrics = {}
-
-        forward_output = self.concatenated_forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
-
-        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
-            policy_chosen_logps,
-            policy_rejected_logps,
-        )
-
-        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = (
-            self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
-        )
-        metrics[f"{prefix}rewards/rejected"] = (
-            self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
-        )
-        metrics[f"{prefix}rewards/accuracies"] = (
-            self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
-        )
-        metrics[f"{prefix}rewards/margins"] = (
-            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
-            .mean()
-            .item()
-        )
-        metrics[f"{prefix}logps/rejected"] = (
-            self.accelerator.gather_for_metrics(policy_rejected_logps)
-            .detach()
-            .mean()
-            .item()
-        )
-        metrics[f"{prefix}logps/chosen"] = (
-            self.accelerator.gather_for_metrics(policy_chosen_logps)
-            .detach()
-            .mean()
-            .item()
-        )
-        metrics[f"{prefix}logits/rejected"] = (
-            self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
-            .mean()
-            .item()
-        )
-        metrics[f"{prefix}logits/chosen"] = (
-            self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
-            .mean()
-            .item()
-        )
-        metrics[f"{prefix}nll_loss"] = (
-            self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
-        )
-        if self.aux_loss_enabled:
-            loss += self.aux_loss_coef * aux_loss
-
-        return loss, metrics
-
-
-class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
+class AxolotlRewardTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
+):
     """
     Extend the base RewardTrainer for axolotl helpers
     """
@@ -253,7 +114,9 @@ class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
     tag_names = ["axolotl", "reward"]


-class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
+class AxolotlPRMTrainer(
+    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
+):
     """
     Extend the base trl.PRMTrainer for axolotl helpers
     """

View File

@@ -164,12 +164,6 @@ class AxolotlTrainingMixins:
        default=None,
        metadata={"help": "whether to use sequential sampling for curriculum learning"},
    )
-    alternate_optimizer: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "workaround to pass an alternate optimizer to the HF trainer"
-        },
-    )
    alternate_lr_scheduler_type: Optional[str] = field(
        default=None,
        metadata={

View File

@@ -1,12 +1,13 @@
"""Module containing Dataset functionality"""

-import logging
import os
from typing import List, Optional, Union

import torch
from datasets import Dataset, IterableDataset

+from axolotl.utils.logging import get_logger
+
from .prompt_tokenizers import PromptTokenizingStrategy

# We want this to be a wrapper for an existing dataset that we have loaded

@@ -15,7 +16,7 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets

-LOG = logging.getLogger("axolotl")
+LOG = get_logger(__name__)


class TokenizedPromptDataset(Dataset):

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections
import importlib
-import logging
from typing import TYPE_CHECKING, Callable, OrderedDict, Union

from peft import PeftModel

@@ -31,6 +30,9 @@ from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer

from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__, use_environ=True)

if TYPE_CHECKING:
    from axolotl.common.datasets import TrainDatasetMeta
@@ -39,31 +41,39 @@ if TYPE_CHECKING:
class BasePlugin:
    """Base class for all plugins. Defines the interface for plugin methods.

-    Methods:
-        register(cfg): Registers the plugin with the given configuration.
-        load_datasets(cfg): Loads and preprocesses the dataset for training.
-        pre_model_load(cfg): Performs actions before the model is loaded.
-        post_model_build(cfg, model): Performs actions after the model is loaded, but
-            before LoRA adapters are applied.
-        pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
-        post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
-        post_model_load(cfg, model): Performs actions after the model is loaded,
-            inclusive of any adapters.
-        post_trainer_create(cfg, trainer): Performs actions after the trainer is
-            created.
-        create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
-        create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
-            returns a learning rate scheduler.
-        add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
-            training.
-        add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
-            training.
+    A plugin is a reusable, modular, and self-contained piece of code that extends
+    the functionality of Axolotl. Plugins can be used to integrate third-party models,
+    modify the training process, or add new features.
+
+    To create a new plugin, you need to inherit from the BasePlugin class and
+    implement the required methods.
+
+    Note:
+        Plugin methods include:
+        - register(cfg): Registers the plugin with the given configuration.
+        - load_datasets(cfg): Loads and preprocesses the dataset for training.
+        - pre_model_load(cfg): Performs actions before the model is loaded.
+        - post_model_build(cfg, model): Performs actions after the model is loaded, but
+            before LoRA adapters are applied.
+        - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
+        - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
+        - post_model_load(cfg, model): Performs actions after the model is loaded,
+            inclusive of any adapters.
+        - post_trainer_create(cfg, trainer): Performs actions after the trainer is
+            created.
+        - create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
+        - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
+            returns a learning rate scheduler.
+        - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
+            training.
+        - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
+            training.
    """

    def __init__(self):
        """Initializes the BasePlugin."""

-    def register(self, cfg):  # pylint: disable=unused-argument
+    def register(self, cfg: DictDefault):  # pylint: disable=unused-argument
        """Registers the plugin with the given configuration.

        Args:
@@ -275,10 +285,11 @@ class PluginManager:
    Attributes:
        plugins: A list of loaded plugins.

-    Methods:
-        get_instance(): Static method to get the singleton instance of `PluginManager`.
-        register(plugin_name: str): Registers a new plugin by its name.
-        pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
+    Note:
+        Key methods include:
+        - get_instance(): Static method to get the singleton instance of `PluginManager`.
+        - register(plugin_name: str): Registers a new plugin by its name.
+        - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
    """

    plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
@@ -322,12 +333,12 @@ class PluginManager:
            ImportError: If the plugin module cannot be imported.
        """
        try:
-            logging.info(f"Attempting to load plugin: {plugin_name}")
+            LOG.info(f"Attempting to load plugin: {plugin_name}")
            plugin = load_plugin(plugin_name)
            self.plugins[plugin_name] = plugin
-            logging.info(f"Plugin loaded successfully: {plugin_name}")
+            LOG.info(f"Plugin loaded successfully: {plugin_name}")
        except ImportError:
-            logging.error(f"Failed to load plugin: {plugin_name}")
+            LOG.error(f"Failed to load plugin: {plugin_name}")

    def get_input_args(self) -> list[str]:
        """Returns a list of Pydantic classes for all registered plugins' input arguments.
@@ -534,7 +545,6 @@ class PluginManager:
        Args:
            cfg: The configuration for the plugins.
-            model: The loaded model.
        """
        for plugin in self.plugins.values():
            plugin.post_train_unload(cfg)
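The docstring above describes the plugin contract: subclass `BasePlugin` and override only the hooks you need. A minimal sketch, with `BasePlugin` stubbed here so the example is self-contained (the real class in `axolotl.integrations.base` defines many more hooks):

```python
class BasePlugin:
    """Stub of axolotl.integrations.base.BasePlugin, for illustration only."""

    def register(self, cfg):
        pass

    def pre_model_load(self, cfg):
        pass

    def add_callbacks_post_trainer(self, cfg, trainer):
        return []


class LoggingPlugin(BasePlugin):
    """Hypothetical plugin that records which hooks fired."""

    def __init__(self):
        self.events = []

    def pre_model_load(self, cfg):
        # Runs before the base model is instantiated
        self.events.append(("pre_model_load", cfg.get("model_config_type")))

    def add_callbacks_post_trainer(self, cfg, trainer):
        # Return extra TrainerCallback-style objects to attach after trainer creation
        self.events.append(("add_callbacks_post_trainer", None))
        return []


plugin = LoggingPlugin()
plugin.pre_model_load({"model_config_type": "llama"})
print(plugin.events)  # [('pre_model_load', 'llama')]
```

The `PluginManager` singleton then calls each registered plugin's hook in order at the corresponding point in the training lifecycle.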

View File

@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""

import importlib
-import logging

import torch

from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
-from axolotl.utils.distributed import is_main_process
+from axolotl.utils.logging import get_logger

from .args import CutCrossEntropyArgs  # pylint: disable=unused-import.  # noqa: F401

-LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
+LOG = get_logger(__name__, use_environ=True)

_CCE_INSTALL_MESSAGE = (
    "Please install cut_cross_entropy with transformers support using "

@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
            cce_patch,
        )

-        if is_main_process(use_environ=True):
-            LOG.info(
-                f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
-            )
+        LOG.info(
+            f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
+        )

        # The patch checks model_type internally
        cce_patch(cfg.model_config_type)

View File

@@ -15,12 +15,13 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""

-import logging
from typing import Optional

from pydantic import BaseModel, model_validator

-LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)


class CutCrossEntropyArgs(BaseModel):

View File

@@ -15,23 +15,14 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
-    MLLAMA_INPUTS_DOCSTRING,
    _prepare_cross_attention_mask,
)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
-)
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,

@@ -164,10 +155,6 @@ def cce_forward(
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
-)
def cce_forward_multimodal(
    self,
    input_ids: Optional[torch.LongTensor] = None,

View File

@@ -2,15 +2,15 @@
Grokfast plugin for Axolotl
"""

-import logging

from transformers.trainer_callback import TrainerCallback

+from axolotl.utils.logging import get_logger
+
from ..base import BasePlugin
from .args import GrokfastArgs  # pylint: disable=unused-import.  # noqa: F401
from .optimizer import gradfilter_ema

-LOG = logging.getLogger("axolotl.integrations.grokfast")
+LOG = get_logger(__name__)


class GrokfastCallbackHandler(TrainerCallback):

View File

@@ -19,16 +19,15 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""

import inspect
-import logging
import sys

from axolotl.integrations.base import BasePlugin
-from axolotl.utils.distributed import is_main_process
+from axolotl.utils.logging import get_logger

from .args import LigerArgs  # pylint: disable=unused-import.  # noqa: F401
from .utils import patch_with_compile_disable

-LOG = logging.getLogger("axolotl.integrations.liger")
+LOG = get_logger(__name__, use_environ=True)


class LigerPlugin(BasePlugin):

@@ -85,10 +84,7 @@ class LigerPlugin(BasePlugin):
                kwargs["geglu"] = cfg.liger_glu_activation
            elif "swiglu" in liger_fn_sig.parameters:
                kwargs["swiglu"] = cfg.liger_glu_activation

-            if is_main_process(use_environ=True):
-                LOG.info(
-                    f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
-                )
+            LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
            apply_liger_fn(**kwargs)
        elif cfg.model_config_type == "jamba":
            from transformers.models.jamba import modeling_jamba

@@ -124,9 +120,9 @@ class LigerPlugin(BasePlugin):
            if cfg.liger_rope:
                # The DeepseekV2 version of RoPE is different than upstream LLaMA.
                # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
-                logging.warning("Fused liger_rope is not supported for DeepseekV2.")
+                LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
            if cfg.liger_glu_activation:
-                logging.warning("liger_glu_activation is not supported for DeepseekV2.")
+                LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
            if cfg.liger_rms_norm:
                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
            if cfg.liger_glu_activation:

@@ -175,7 +171,17 @@ class LigerPlugin(BasePlugin):
                rms_norm=cfg.liger_rms_norm,
                layer_norm=cfg.liger_layer_norm,
            )
+        elif cfg.model_config_type == "granitemoe":
+            from liger_kernel.transformers import apply_liger_kernel_to_granite
+
+            apply_liger_kernel_to_granite(
+                rope=cfg.liger_rope,
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                rms_norm=cfg.liger_rms_norm,
+                swiglu=cfg.liger_glu_activation,
+            )
        else:
-            logging.warning(
+            LOG.warning(
                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
            )

View File

@@ -15,12 +15,13 @@
"""
Module for handling LIGER input arguments.
"""

-import logging
from typing import Optional

from pydantic import BaseModel, model_validator

-LOG = logging.getLogger("axolotl.integrations.liger.args")
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)


class LigerArgs(BaseModel):

View File

@@ -3,7 +3,6 @@ Sparse Finetuning plugin for Axolotl — enables handling of sparse neural netwo
by maintaining masks for zero weights during training.
"""

-import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar

@@ -16,11 +15,12 @@ from transformers.trainer_callback import TrainerCallback, TrainerControl, Train
from transformers.training_args import TrainingArguments

from axolotl.integrations.base import BasePlugin
+from axolotl.utils.logging import get_logger

P = ParamSpec("P")  # Params for generic function signatures
R = TypeVar("R")  # Return type for generic function signatures

-LOG = logging.getLogger("axolotl.integrations.llm_compressor")
+LOG = get_logger(__name__)


class LLMCompressorCallbackHandler(TrainerCallback):

View File

@@ -17,14 +17,16 @@ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""

import json
-import logging

import requests

from axolotl.integrations.base import BasePlugin
+from axolotl.utils.logging import get_logger

from .args import SpectrumArgs  # pylint: disable=unused-import.  # noqa: F401

+LOG = get_logger(__name__)
+

def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
    unfrozen_parameters = {}

@@ -83,17 +85,17 @@ class SpectrumPlugin(BasePlugin):
            except FileNotFoundError:
                pass
            except Exception as exc:  # pylint: disable=broad-exception-caught
-                logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
+                LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")

        if not snr_data:
            try:
                snr_data = requests.get(snr_url, timeout=60).json()
            except requests.exceptions.RequestException as exc:
-                logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
+                LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
                return
            # also catch json parsing errors
            except json.JSONDecodeError as exc:
-                logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
+                LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
                return

        unfrozen_parameters = _generate_unfrozen_params_yaml(

View File

@@ -280,19 +280,19 @@ class LoRA_MLP(torch.autograd.Function):
        # Initialize and compute LoRA gradients
        d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None

-        if down_A is not None:
+        if down_A is not None and down_B is not None:
            d_down_A = h.t() @ (grad_output @ down_B.t())
            d_down_B = (down_A.t() @ h.t()) @ grad_output
            d_down_A *= down_scale
            d_down_B *= down_scale

-        if up_A is not None:
+        if up_A is not None and up_B is not None:
            d_up_A = X.t() @ (grad_up @ up_B.t())
            d_up_B = (up_A.t() @ X.t()) @ grad_up
            d_up_A *= up_scale
            d_up_B *= up_scale

-        if gate_A is not None:
+        if gate_A is not None and gate_B is not None:
            d_gate_A = X.t() @ (grad_gate @ gate_B.t())
            d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
            d_gate_A *= gate_scale

@@ -311,7 +311,7 @@ class LoRA_MLP(torch.autograd.Function):
        del up_weight

        # Note the .to(dtype) only where mixing LoRA with base weights
-        if up_A is not None:
+        if up_A is not None and up_B is not None:
            dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())

        # Gate projection gradients

@@ -319,7 +319,7 @@ class LoRA_MLP(torch.autograd.Function):
        dX += grad_gate @ gate_weight.t()
        del gate_weight

-        if gate_A is not None:
+        if gate_A is not None and gate_B is not None:
            dX += (
                grad_gate
                @ gate_B.to(dtype).t()
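The A/B gradient formulas guarded above can be checked directly against autograd. A small standalone sketch with random tensors (shapes and scale are arbitrary illustration values):

```python
import torch

torch.manual_seed(0)

X = torch.randn(4, 8, dtype=torch.float64)                      # activations (batch x in_features)
A = torch.randn(8, 2, dtype=torch.float64, requires_grad=True)  # LoRA A (in_features x rank)
B = torch.randn(2, 8, dtype=torch.float64, requires_grad=True)  # LoRA B (rank x out_features)
scale = 0.5                                                     # LoRA scaling factor

out = (X @ A @ B) * scale
grad_out = torch.randn_like(out)  # stand-in for upstream grad_output
out.backward(grad_out)

# Manual gradients in the same form as the patched backward above:
#   d_A = X^T (grad @ B^T) * scale
#   d_B = (A^T X^T) grad * scale
d_A = X.t() @ (grad_out @ B.t()) * scale
d_B = (A.t() @ X.t()) @ grad_out * scale

assert torch.allclose(A.grad, d_A)
assert torch.allclose(B.grad, d_B)
```

The extra `and ..._B is not None` guards simply avoid dereferencing a missing B matrix when only one half of an adapter pair is present.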

View File

@@ -1,6 +1,5 @@
"""Adapter loading functionality, including LoRA / QLoRA and associated utils"""

-import logging
import os
import types
from typing import Any

@@ -21,8 +20,9 @@ from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


def setup_quantized_meta_for_peft(model: torch.nn.Module):

View File

@@ -3,7 +3,6 @@ models.
"""

import gc
-import logging
import math
import os
from functools import cached_property

@@ -15,7 +14,13 @@ import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
-from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training
+from peft import (
+    PeftConfig,
+    PeftMixedModel,
+    PeftModel,
+    PeftModelForCausalLM,
+    prepare_model_for_kbit_training,
+)
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,

@@ -47,10 +52,11 @@ from axolotl.utils.distributed import (
    get_device_count,
    get_device_type,
)
+from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

@@ -139,7 +145,7 @@ class ModelLoader:
        """Property that determines if FSDP with QLoRA is enabled."""
        return self.cfg.fsdp and self.cfg.adapter == "qlora"

-    def load(self) -> tuple[PreTrainedModel, PeftConfig | None]:
+    def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
        """Load and prepare the model with all configurations and patches.

        Returns:

@@ -191,6 +197,7 @@ class ModelLoader:
        self._adjust_model_config()
        self._log_memory_usage()
        self._configure_embedding_dtypes()
+        self._configure_qat()

    def _resize_token_embeddings(self):
        """Resize token embeddings if needed."""

@@ -305,6 +312,19 @@ class ModelLoader:
            before_kbit_train_or_finetune=False,
        )

+    def _configure_qat(self):
+        """Configure QAT."""
+        if self.cfg.qat:
+            from axolotl.utils.quantization import prepare_model_for_qat
+
+            prepare_model_for_qat(
+                self.model,
+                self.cfg.qat.weight_dtype,
+                self.cfg.qat.group_size,
+                self.cfg.qat.activation_dtype,
+                self.cfg.qat.quantize_embedding,
+            )
+
    def _load_adapters(self) -> PeftConfig | None:
        """Load LoRA or other adapters."""
        # Load LoRA or adapter

@@ -536,11 +556,18 @@ class ModelLoader:
        if self.cfg.low_cpu_mem_usage:
            self.model_kwargs["low_cpu_mem_usage"] = True

-    def _configure_zero3_memory_efficient_loading(self):
-        """Set the deepspeed config to load the model into RAM first before moving
-        to VRAM.
+    def _configure_zero3_memory_efficient_loading(
+        self,
+    ) -> HfTrainerDeepSpeedConfig | None:
+        """
+        Set the deepspeed config to load the model into RAM first before moving to VRAM.

-        We need to return `hf_ds_cfg` as it needs to exist before model loading.
+        IMPORTANT
+        ==========
+        We need to return `hf_ds_cfg` as it needs to exist before model loading for zero3.
+        HfTrainerDeepSpeedConfig is a class that is used to configure the DeepSpeed training.
+        It is not passed anywhere in the model loading function, just need to exist.
        """
        hf_ds_cfg = None

@@ -605,7 +632,8 @@ class ModelLoader:
        if "device_map" in self.model_kwargs:
            del self.model_kwargs["device_map"]

-        self._configure_zero3_memory_efficient_loading()
+        # Please don't remove underscore binding without reading the fn docstring.
+        _ = self._configure_zero3_memory_efficient_loading()

        # Load model with random initialization if specified
        if self.cfg.random_init_weights:

@@ -675,7 +703,8 @@ class ModelLoader:
        if "device_map" in self.model_kwargs:
            del self.model_kwargs["device_map"]

-        self._configure_zero3_memory_efficient_loading()
+        # Please don't remove underscore binding without reading the fn docstring.
+        _ = self._configure_zero3_memory_efficient_loading()

        self.model = self.auto_model_loader.from_pretrained(
            self.base_model,

View File

@@ -4,7 +4,6 @@ Applies pre- and post-model load patches for various fixes and optimizations.
"""

import importlib.util
-import logging
from functools import cached_property

import addict

@@ -17,8 +16,9 @@ from axolotl.monkeypatch.multipack import (
    patch_for_multipack,
)
from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

@@ -59,9 +59,10 @@ class PatchManager:
        self._apply_gradient_checkpointing_patches()
        self._patch_attention()
        self._apply_multipack_patches()
+        self._patch_loss_llama()
        self._patch_llama_derived_model()
        self._apply_mistral_cross_entropy_patch()
-        self._apply_unsloth_self_attention_patch()
+        self._apply_self_attention_lora_patch()

    def apply_post_model_load_patches(self, model: PreTrainedModel):
        """Apply patches that require the model instance."""

@@ -80,9 +81,9 @@ class PatchManager:
    def _apply_fsdp_patches(self):
        """Apply patches for FSDP configurations."""
        if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
-            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
+            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2

-            patch_accelerate_fsdp_utils()
+            patch_accelerate_fsdp2()

    def _apply_adapter_patches(self):
        """Apply patches for adapter configurations."""

@@ -115,13 +116,6 @@ class PatchManager:
            patch_llama4_linearized_modeling()

-        if self.cfg.model_config_type == "gemma3":
-            from axolotl.monkeypatch.gemma3 import (
-                patch_gemma3conditionalgeneration_forward,
-            )
-
-            patch_gemma3conditionalgeneration_forward()
-
    def _apply_fp8_patches(self):
        """Apply patches for FP8 support."""
        if self.cfg.fp8:

@@ -169,9 +163,9 @@ class PatchManager:
            patch_mistral_cross_entropy()

-    def _apply_unsloth_self_attention_patch(self):
-        """Apply Unsloth self-attention patches if configured."""
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
+    def _apply_self_attention_lora_patch(self):
+        """Apply self-attention LoRA patches if configured."""
+        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

            patch_self_attn_lora(self.cfg)

@@ -206,19 +200,11 @@ class PatchManager:
            has_remote_code=has_remote_code,
        )

-        if self.cfg.is_llama_derived_model:
-            self._patch_loss_llama()
-
    def _patch_attention(self):
        """Apply attention-specific patches based on model type."""
        if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
            return

-        if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
-            from axolotl.monkeypatch.attention.mllama import patch_mllama
-
-            patch_mllama()
-
        if self.model_config.model_type == "btlm":
            from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                replace_btlm_attn_with_flash_attn,

@@ -235,6 +221,9 @@ class PatchManager:
    def _patch_loss_llama(self):
        """Patch loss functions and other optimizations for LLaMA models."""
+        if not self.cfg.is_llama_derived_model:
+            return
+
        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                patch_fa_llama_cross_entropy,

@@ -314,8 +303,6 @@ class PatchManager:
            and (self.cfg.flash_attention or self.cfg.flex_attention)
            and self.cfg.sample_packing
        ):
-            self._patch_loss_llama()
-
            if self.cfg.flash_attention:
                self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
            elif self.cfg.xformers_attention:

View File

@@ -1,6 +1,5 @@
"""Processor loading functionality for multi-modal models"""

-import logging
from typing import Any

import transformers

@@ -10,8 +9,9 @@ from transformers import (
)

from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger

-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)


def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):


@@ -1,7 +1,6 @@
 """Tokenizer loading functionality and associated utils"""
 import json
-import logging
 import os
 import transformers
@@ -19,8 +18,9 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     is_main_process,
 )
+from axolotl.utils.logging import get_logger
-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)
 PLUGIN_MANAGER = PluginManager.get_instance()


@@ -1,7 +1,6 @@
 """Utilities for axolotl.loaders module"""
 import contextlib
-import logging
 from typing import Type
 import addict
@@ -9,8 +8,9 @@ import torch
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
-LOG = logging.getLogger(__name__)
+LOG = get_logger(__name__)
 def get_module_class_from_name(


@@ -2,14 +2,56 @@
 Common logging module for axolotl
 """
+import logging
 import os
 import sys
-from logging import Formatter
+from logging import Formatter, Logger, LogRecord
 from logging.config import dictConfig
 from typing import Any, Dict
 from colorama import Fore, Style, init
+DEFAULT_AXOLOTL_LOG_LEVEL = "INFO"
+DEFAULT_LOG_LEVEL = "WARNING"
+class AxolotlOrWarnErrorFilter(logging.Filter):
+    """
+    Allows ANY record at WARNING or higher (unless overridden by LOG_LEVEL).
+    Allows axolotl.* records at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL).
+    Drops all other records (i.e., non-axolotl INFO, DEBUG, etc. by default).
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.axolotl_level = logging.getLevelNamesMapping()[
+            os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL)
+        ]
+        self.other_level = logging.getLevelNamesMapping()[
+            os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
+        ]
+    def filter(self, record: LogRecord) -> bool:
+        # General filter
+        if record.levelno >= self.other_level:
+            return True
+        # Axolotl filter
+        return (
+            record.name.startswith("axolotl") and record.levelno >= self.axolotl_level
+        )
+class AxolotlLogger(Logger):
+    """A Logger that automatically rejects non-axolotl INFO records."""
+    def __init__(self, name: str, level: int = logging.NOTSET):
+        super().__init__(name, level)
+        # set a global filter on the logger itself
+        self.addFilter(AxolotlOrWarnErrorFilter())
 class ColorfulFormatter(Formatter):
     """
@@ -55,11 +97,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
             "stream": sys.stdout,
         },
     },
-    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
+    # log level will be superseded by the AxolotlLogger
+    "root": {
+        "handlers": ["console"],
+        "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
+    },
     "loggers": {
         "axolotl": {
             "handlers": ["color_console"],
-            "level": "DEBUG",
+            "level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL),
             "propagate": False,
         },
     },
@@ -70,3 +116,8 @@ def configure_logging():
     """Configure with default logging"""
     init()  # Initialize colorama
     dictConfig(DEFAULT_LOGGING_CONFIG)
+    logging.setLoggerClass(AxolotlLogger)
+    # set `ACCELERATE_LOG_LEVEL` from `LOG_LEVEL` if not already set
+    if "ACCELERATE_LOG_LEVEL" not in os.environ:
+        os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
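For reference, the filtering rule the new `AxolotlOrWarnErrorFilter` implements can be sketched standalone. This is an illustrative reimplementation with the env-var lookups replaced by constructor defaults, not the axolotl class itself:

```python
import logging

class OrWarnErrorFilter(logging.Filter):
    """Pass WARNING+ from any logger; pass INFO+ only from axolotl.* loggers."""

    def __init__(self, axolotl_level=logging.INFO, other_level=logging.WARNING):
        super().__init__()
        self.axolotl_level = axolotl_level
        self.other_level = other_level

    def filter(self, record: logging.LogRecord) -> bool:
        # General rule: anything at or above WARNING passes
        if record.levelno >= self.other_level:
            return True
        # Axolotl rule: axolotl.* records pass at INFO and above
        return record.name.startswith("axolotl") and record.levelno >= self.axolotl_level

def make_record(name: str, level: int) -> logging.LogRecord:
    return logging.LogRecord(name, level, __file__, 0, "msg", None, None)

flt = OrWarnErrorFilter()
assert flt.filter(make_record("transformers", logging.WARNING))     # any WARNING passes
assert flt.filter(make_record("axolotl.utils.data", logging.INFO))  # axolotl INFO passes
assert not flt.filter(make_record("transformers", logging.INFO))    # third-party INFO dropped
assert not flt.filter(make_record("axolotl.utils", logging.DEBUG))  # DEBUG dropped by default
```

Raising `LOG_LEVEL` or `AXOLOTL_LOG_LEVEL` in the real class shifts the two thresholds accordingly.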


@@ -1,13 +1,14 @@
 """
-monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
+monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
 """
-import logging
 import sys
 import torch
-LOG = logging.getLogger(__name__)
+from axolotl.utils.logging import get_logger
+LOG = get_logger(__name__)
 def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
@@ -17,27 +18,65 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
     Args:
         accelerator (`Accelerator`): The accelerator instance
-        model (`torch.nn.Module`): The model to load the state dict into
+        model (`torch.nn.Module`):
+            The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
         full_sd (`dict`): The full state dict to load, can only be on rank 0
     """
     import torch.distributed as dist
     from torch.distributed.tensor import distribute_tensor
-    LOG.info("Broadcasting full state dict to all ranks...")
-    sharded_sd = model.state_dict()
-    param_names = sorted(sharded_sd.keys())
-    # Rank 0 distributes the full state dict to other ranks
+    # Model was previously copied to meta device
+    meta_sharded_sd = model.state_dict()
+    sharded_sd = {}
+    def _infer_parameter_dtype(model, param_name, empty_param):
+        try:
+            old_param = model.get_parameter_or_buffer(param_name)
+        except AttributeError:
+            # Needed for LoRA, as some params there are not *parameters* of sorts
+            base_param_name, local_param_name = param_name.rsplit(".", 1)
+            submodule = model.get_submodule(base_param_name)
+            old_param = getattr(submodule, local_param_name)
+        is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
+        casting_dtype = None
+        is_param_float8_e4m3fn = (
+            is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
+        )
+        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
+            casting_dtype = old_param.dtype
+        return old_param is not None and old_param.is_contiguous(), casting_dtype
+    def _cast_and_contiguous(tensor, to_contiguous, dtype):
+        if dtype is not None:
+            tensor = tensor.to(dtype=dtype)
+        if to_contiguous:
+            tensor = tensor.contiguous()
+        return tensor
+    param_names = sorted(meta_sharded_sd.keys())
     for param_name in param_names:
-        mesh = sharded_sd[param_name].device_mesh
+        mesh = meta_sharded_sd[param_name].device_mesh
         if accelerator.is_main_process:
-            # Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
             full_param = full_sd[param_name].detach().cuda()
             dist.broadcast(full_param, src=0, group=mesh.get_group())
             sharded_tensor = distribute_tensor(
                 full_param, mesh, sharded_sd[param_name].placements
             )
+            to_contiguous, casting_dtype = _infer_parameter_dtype(
+                model,
+                param_name,
+                full_param,
+            )
+            sharded_tensor = _cast_and_contiguous(
+                sharded_tensor, to_contiguous, casting_dtype
+            )
             sharded_sd[param_name] = sharded_tensor
         else:
-            # Prepare a tensor of matching shape and dtype
             full_tensor = torch.empty(
                 sharded_sd[param_name].size(),
                 device="cuda",
@@ -47,12 +86,113 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
             sharded_tensor = distribute_tensor(
                 full_tensor, mesh, sharded_sd[param_name].placements
             )
+            to_contiguous, casting_dtype = _infer_parameter_dtype(
+                model,
+                param_name,
+                full_tensor,
+            )
+            sharded_tensor = _cast_and_contiguous(
+                sharded_tensor, to_contiguous, casting_dtype
+            )
             sharded_sd[param_name] = sharded_tensor
+    # we set `assign=True` because our params are on meta device
     model.load_state_dict(sharded_sd, assign=True)
+    return model
-def patch_accelerate_fsdp_utils():
+def get_state_dict(self, model, unwrap=True):
+    """
+    Returns the state dictionary of a model sent through [`Accelerator.prepare`], potentially without full
+    precision.
+    Args:
+        model (`torch.nn.Module`):
+            A PyTorch model sent through [`Accelerator.prepare`]
+        unwrap (`bool`, *optional*, defaults to `True`):
+            Whether to return the original underlying state_dict of `model` or the wrapped state_dict
+    Returns:
+        `dict`: The state dictionary of the model, potentially without full precision.
+    Example:
+    ```python
+    >>> import torch
+    >>> from accelerate import Accelerator
+    >>> accelerator = Accelerator()
+    >>> net = torch.nn.Linear(2, 2)
+    >>> net = accelerator.prepare(net)
+    >>> state_dict = accelerator.get_state_dict(net)
+    ```
+    """
+    from accelerate import DistributedType
+    from accelerate.utils import compare_versions
+    if self.distributed_type == DistributedType.DEEPSPEED:
+        zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
+        tp_sharding = (
+            self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
+        )
+        if zero3_sharding or tp_sharding:
+            if model.zero_gather_16bit_weights_on_model_save():
+                if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
+                    raise ImportError(
+                        "DeepSpeed TP requires deepspeed >= 0.16.4. Please update DeepSpeed via `pip install deepspeed -U`."
+                    )
+                state_dict = (
+                    model._consolidated_16bit_state_dict()  # pylint: disable=protected-access
+                    if tp_sharding
+                    else model._zero3_consolidated_16bit_state_dict()  # pylint: disable=protected-access
+                )
+            else:
+                raise ValueError(
+                    "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
+                    "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in the DeepSpeed config file or "
+                    "set `zero3_save_16bit_model` to True when using `accelerate config`. "
+                    "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
+                )
+        else:
+            from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
+            state_dict = clone_tensors_for_torch_save(
+                self.unwrap_model(model).state_dict()
+            )
+    elif self.is_fsdp2:
+        # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
+        state_dict = {}
+        sharded_state_dict = model.state_dict()
+        for param_name, param in sharded_state_dict.items():
+            if param.is_cpu:
+                param = param.to(torch.device("cuda"))
+            param = param.full_tensor()
+            if torch.distributed.get_rank() == 0:
+                state_dict[param_name] = param.cpu()
+        torch.distributed.barrier()
+    elif self.distributed_type == DistributedType.FSDP:
+        from torch.distributed.fsdp import FullStateDictConfig
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        from torch.distributed.fsdp import StateDictType
+        full_state_dict_config = FullStateDictConfig(
+            offload_to_cpu=True, rank0_only=True
+        )
+        with FSDP.state_dict_type(
+            model, StateDictType.FULL_STATE_DICT, full_state_dict_config
+        ):
+            state_dict = model.state_dict()
+    else:
+        if unwrap:
+            model = self.unwrap_model(model)
+        state_dict = model.state_dict()
+    return state_dict
+def patch_accelerate_fsdp2():
+    import accelerate
     from accelerate.utils import fsdp_utils
     fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
@@ -61,3 +201,10 @@ def patch_accelerate_fsdp_utils():
         "fsdp2_load_full_state_dict",
         fsdp2_load_full_state_dict,
     )
+    accelerate.Accelerator.get_state_dict = get_state_dict
+    setattr(
+        sys.modules["accelerate"],
+        "Accelerator.get_state_dict",
+        get_state_dict,
+    )
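The `_infer_parameter_dtype` / `_cast_and_contiguous` pair added above decides, per parameter, whether the sharded tensor should be cast back to the dtype recorded on the meta model. A minimal sketch of that decision, with torch dtypes replaced by plain flags (illustrative only, not the real tensor API):

```python
def infer_casting_dtype(is_floating_point: bool, is_float8_e4m3fn: bool, old_dtype):
    """Floating-point params (except float8_e4m3fn) are cast back to the dtype
    recorded on the meta model; everything else is left untouched."""
    if is_floating_point and not is_float8_e4m3fn:
        return old_dtype
    return None

assert infer_casting_dtype(True, False, "bfloat16") == "bfloat16"  # normal float param: cast
assert infer_casting_dtype(True, True, "bfloat16") is None         # float8 kept as-is
assert infer_casting_dtype(False, False, "int64") is None          # integer buffers untouched
```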


@@ -1,230 +0,0 @@
-"""
-Monkeypatch for Vision Llama for FA2 support
-"""
-# pylint: disable=duplicate-code
-from typing import Optional, Tuple
-import torch
-from flash_attn.flash_attn_interface import flash_attn_func
-from transformers.cache_utils import Cache
-from transformers.modeling_flash_attention_utils import _flash_attention_forward
-from transformers.models.mllama.configuration_mllama import MllamaTextConfig
-from transformers.models.mllama.modeling_mllama import (
-    MllamaTextCrossAttention,
-    MllamaTextSelfAttention,
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
-from transformers.utils import is_flash_attn_greater_or_equal_2_10
-class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
-    """
-    Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
-    implements the forward pass using Flash Attention for improved performance.
-    """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Check if flash attention version is greater or equal to 2.1
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Cache] = None,
-        attention_mask: Optional[  # pylint: disable=unused-argument
-            torch.Tensor
-        ] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,  # pylint: disable=unused-argument
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        query_states = self.q_norm(query_states)
-        if cross_attention_states is not None:
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(
-                bsz, -1, self.num_key_value_heads, self.head_dim
-            ).transpose(1, 2)
-            value_states = value_states.view(
-                bsz, -1, self.num_key_value_heads, self.head_dim
-            ).transpose(1, 2)
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
-            key_states = self.k_norm(key_states)
-            if past_key_value is not None:
-                key_states, value_states = past_key_value.update(
-                    key_states,
-                    value_states,
-                    self.layer_idx,
-                    {"cache_position": cache_position},
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_value.key_cache[self.layer_idx],
-                past_key_value.value_cache[self.layer_idx],
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        # Transpose to get the expected layout for flash attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        # Apply Flash Attention
-        dropout_rate = self.dropout if self.training else 0.0
-        output = flash_attn_func(
-            query_states,
-            key_states,
-            value_states,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=False,
-            return_attn_probs=output_attentions,
-        )
-        attn_output = output.contiguous().view(bsz, q_len, -1)
-        attn_output = self.o_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
-class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
-    """
-    Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
-    implements the forward pass using Flash Attention for improved performance.
-    """
-    def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
-        super().__init__(config, layer_idx, *args, **kwargs)
-        # Check if flash attention version is greater or equal to 2.1
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,  # pylint: disable=unused-argument
-        past_key_value=None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # pylint: disable=unused-argument
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        output_attentions = False
-        bsz, q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x num_heads x head_dim
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin
-        )
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        # Transpose to get the expected layout for flash attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        dropout_rate = self.dropout if self.training else 0.0
-        # Handle potential silent casting to float32
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = (
-                    self.config._pre_quantization_dtype  # pylint: disable=protected-access
-                )
-            else:
-                target_dtype = self.q_proj.weight.dtype
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-            is_causal=True,
-        )
-        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
-        return attn_output, attn_weights, past_key_value
-def patch_mllama():
-    from transformers.models.mllama.modeling_mllama import (
-        MLLAMA_TEXT_ATTENTION_CLASSES,
-        MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
-        MLLAMA_VISION_ATTENTION_CLASSES,
-        MllamaPreTrainedModel,
-    )
-    MllamaPreTrainedModel._supports_flash_attn_2 = (  # pylint: disable=protected-access
-        True
-    )
-    MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
-    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
-        MllamaTextCrossFlashAttention2
-    )
-    # fallback to SDPA
-    MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
-        MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
-    )


@@ -3,7 +3,6 @@ Flash attention monkey patch for cerebras btlm model
 """
 import importlib
-import logging
 from typing import Optional, Tuple
 import torch
@@ -11,7 +10,9 @@ from accelerate import init_empty_weights
 from flash_attn.flash_attn_interface import flash_attn_func
 from transformers import AutoConfig, AutoModelForCausalLM
-LOG = logging.getLogger("axolotl")
+from axolotl.utils.logging import get_logger
+LOG = get_logger(__name__)
 def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):


@@ -1,230 +0,0 @@
-"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
-# pylint: disable=duplicate-code
-from typing import Optional, Tuple, Union
-import torch
-from transformers.cache_utils import Cache
-from transformers.models.gemma3.modeling_gemma3 import (
-    Gemma3CausalLMOutputWithPast,
-    logger,
-)
-from transformers.utils import (
-    is_torchdynamo_compiling,
-)
-from transformers.utils.deprecation import deprecate_kwarg
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-def new_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    pixel_values: torch.FloatTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
-    token_type_ids: Optional[torch.LongTensor] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    logits_to_keep: Union[int, torch.Tensor] = 0,
-    **lm_kwargs,
-) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
-    r"""
-    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-    logits_to_keep (`int` or `torch.Tensor`, *optional*):
-        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-        This is useful when using packed tensor format (single dimension for batch and sequence length).
-    Returns:
-    Example:
-    ```python
-    >>> from PIL import Image
-    >>> import requests
-    >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
-    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
-    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
-    >>> prompt = "answer en Where is the cow standing?"
-    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
-    >>> image = Image.open(requests.get(url, stream=True).raw)
-    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-    >>> # Generate
-    >>> generate_ids = model.generate(**inputs, max_length=30)
-    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "answer en Where is the cow standing?\nbeach"
-    ```"""
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-    is_training = token_type_ids is not None and labels is not None
-    # Replace image id with PAD if the image token is OOV, to avoid index-errors
-    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
-        special_image_mask = input_ids == self.config.image_token_index
-        llm_input_ids = input_ids.clone()
-        llm_input_ids[special_image_mask] = 0
-    else:
-        llm_input_ids = input_ids
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-    if cache_position is None:
-        past_seen_tokens = (
-            past_key_values.get_seq_length() if past_key_values is not None else 0
-        )
-        cache_position = torch.arange(
-            past_seen_tokens,
-            past_seen_tokens + inputs_embeds.shape[1],
-            device=inputs_embeds.device,
-        )
-    # Merge text and images
-    if pixel_values is not None:
-        image_features = self.get_image_features(pixel_values)
-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(
-                    self.config.image_token_index,
-                    dtype=torch.long,
-                    device=inputs_embeds.device,
-                )
-            )
-        else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-                -1
-            )
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-        if (
-            not is_torchdynamo_compiling()
-            and inputs_embeds[special_image_mask].numel() != image_features.numel()
-        ):
-            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-            raise ValueError(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                "tokens from image embeddings."
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-    # mask out pad-token-ids in labels for BC
-    if labels is not None and self.pad_token_id in labels:
-        logger.warning_once(
-            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-        )
-        labels = torch.where(
-            input_ids == self.pad_token_id, self.config.ignore_index, labels
-        )
-    causal_mask = self._update_causal_mask(  # pylint: disable=protected-access
-        attention_mask,
-        token_type_ids,
-        past_key_values,
-        cache_position,
-        inputs_embeds,
-        is_training,
-    )
-    outputs = self.language_model(
-        attention_mask=causal_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
-        **lm_kwargs,
-    )
-    logits = outputs[0]
-    loss = None
-    if labels is not None:
-        if attention_mask is not None:
-            # Get the shifted attention mask
-            shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
-                logits.device
-            )  # +1 for shift
-            # Filter logits and labels based on attention mask
-            valid_indices = shift_attention_mask != 0
-            filtered_logits = logits[..., :-1, :][valid_indices]
-            filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]
-            # TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
-            loss = self.loss_function(
-                logits=filtered_logits,
-                labels=None,  # we pass shift_labels
-                shift_labels=filtered_labels,
-                vocab_size=self.config.text_config.vocab_size,
-                **lm_kwargs,
-            )
-        else:
-            # Standard case without filtering
-            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.config.text_config.vocab_size,
-                **lm_kwargs,
-            )
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-    return Gemma3CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
-    )
-def patch_gemma3conditionalgeneration_forward():
-    from transformers.models.gemma3.modeling_gemma3 import (
-        Gemma3ForConditionalGeneration,
-    )
-    Gemma3ForConditionalGeneration.forward = new_forward


@@ -18,7 +18,6 @@ DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
 import atexit
 import concurrent.futures
-import logging
 import os
 import queue
 import shutil
@@ -32,11 +31,13 @@ from typing import Dict
 import torch
+from axolotl.utils.logging import get_logger
 torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
 torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
 # Setup logger
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 class DiskOffloadManager:


@@ -2,7 +2,6 @@
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
-import logging
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -25,6 +24,7 @@ from transformers.models.llama.modeling_llama import (
 )
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
+from axolotl.utils.logging import get_logger
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
@@ -41,7 +41,7 @@ except ImportError:
     )
-LOG = logging.getLogger("axolotl")
+LOG = get_logger(__name__)
 def is_xformers_available() -> bool:
@@ -612,9 +612,10 @@ def generate_qkv(
             q, query_padding_mask
         )
-        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
-            output_unpad, indices_q, batch_size, seqlen_q
-        )
+        def output_pad_fn(output_unpad):
+            return pad_input(  # noqa: E731
+                output_unpad, indices_q, batch_size, seqlen_q
+            )
     else:
         q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -627,9 +628,10 @@ def generate_qkv(
         )
         max_seqlen_q = seqlen_q
-        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
-            output_unpad, "(b s) h d -> b s h d", b=batch_size
-        )
+        def output_pad_fn(output_unpad):
+            return rearrange(  # noqa: E731
+                output_unpad, "(b s) h d -> b s h d", b=batch_size
+            )
     if key_padding_mask is not None:
         k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)


@@ -2,7 +2,6 @@
 Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
 """
-import logging
 import warnings
 from typing import Optional, Tuple
@@ -11,10 +10,14 @@ import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from axolotl.utils.logging import get_logger
+LOG = get_logger(__name__)
 try:
     import xformers.ops
 except ImportError:
-    logging.error("xformers not found! Please install it before trying to use it.")
+    LOG.error("xformers not found! Please install it before trying to use it.")
 def hijack_llama_attention():
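The attention hijacks in this diff all rely on the same mechanism: rebinding an attribute on an already-imported module so that every later lookup resolves to the replacement. A miniature of that pattern, with invented module and function names for illustration (not axolotl or transformers code):

```python
import sys
import types

# Build a stand-in "library" module the way one would look after import
mod = types.ModuleType("demo_lib")

def original_attention(x):
    # original implementation
    return x

mod.attention = original_attention
sys.modules["demo_lib"] = mod

def flash_attention(x):
    # patched implementation
    return x * 2

# The hijack: rebind the module attribute; later lookups see the new function
sys.modules["demo_lib"].attention = flash_attention

import demo_lib  # resolves to the entry already in sys.modules
assert demo_lib.attention(3) == 6  # callers now get the patched function
```

Code that did `from demo_lib import attention` before the patch keeps its old reference, which is why these patches run early, before model classes bind the attention implementations.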

Some files were not shown because too many files have changed in this diff.