Compare commits

...

60 Commits

Author SHA1 Message Date
Dan Saunders
4f39aeefb9 debug 2025-06-09 20:38:46 +00:00
Dan Saunders
8f75136ad3 debug 2025-06-09 20:38:13 +00:00
Dan Saunders
70e9cb545d update mistral dep version 2025-06-09 18:03:45 +00:00
Dan Saunders
aa236a4669 use from_hf_hub 2025-06-09 18:03:43 +00:00
Dan Saunders
65f8988efd small changes 2025-06-09 18:03:31 +00:00
Dan Saunders
13ddb8f172 Simplify mistral tokenizer identification (depends on upstream PR) 2025-06-09 18:03:31 +00:00
Dan Saunders
b1570ed0fa update 2025-06-09 18:03:31 +00:00
Dan Saunders
9581a9efed refactor tokenizer loader + add mistral logic 2025-06-09 18:03:28 +00:00
Dan Saunders
7e44445494 add mistral-common dep 2025-06-09 18:02:28 +00:00
Wing Lian
09c685fd2c fix worker_init_fn signature handling (#2769) 2025-06-08 23:14:10 -07:00
Wing Lian
7909bfb076 add manual seed for flaky test_geglu_backward test (#2763) [skip ci] 2025-06-05 09:23:17 -07:00
Wing Lian
cb03c765a1 add uv tooling for e2e gpu tests (#2750)
* add uv tooling for e2e gpu tests

* fixes from PR feedback

* simplify check

* fix env var

* make sure to use uv for other install

* use raw_dockerfile_image

* Fix import

* fix args to experimental dockerfile image call

* use updated modal versions
2025-06-05 07:25:06 -07:00
Timofey Klyubin
4440b4a1ce remove unused field for chat_template.default for DPO training (#2755) [skip ci]
* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

* lint and fix tests for new return value

* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

lint and fix tests for new return value

fix for updated expected fields for dpo

remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

fix test still expecting "messages" field

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-05 07:22:58 -07:00
NanoCode012
e8e45b3441 fix: remove hqq (#2759) [skip ci] 2025-06-05 07:22:23 -07:00
Wing Lian
c67910fa6f bump hf deps (#2735) [skip ci]
* bump hf deps

* upgrade liger-kernel too

* install cce from fork for transformers fix

* fix reference to vocab size in gemma3 patch

* use padding_idx instead of pad_token_id

* remove fixed gemma3 patch

* use updated cce fork

* fix local mllama cce patches w docstring

* add test for multipack with trainer setup and fix trainer for trainer refactor upstream

* bump modal version

* guard for iterable datasetS

* mllama model arch layout changed in latest transformers

* fix batch sampler with drop_last

* fix: address upstream vlm changes for lora

* fix: update references to old lora target path

* fix: remove mllama fa2 patch due to upstream fix

* fix: lora kernel patch path for multimodal models

* fix: removed mllama from quarto

* run test for came optim on 2.6.0+

* fix fsdp2 patch and remove deprecated patch

* make sure to set sequence_parallel_degree for grpo

* Add SP test for GRPO

* add sp to grpo config for trainer

* use reward_funcs as kwarg to grpo trainer

* fix the comprehension for reward funcs

* reward funcs already passed in as args

* init sp_group right before training

* fix check for adding models to SP context

* make sure to pass args to super

* upgrade deepspeed

* use updated trl and add reasoning flags for vllm

* patch the worker

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-06-05 07:20:33 -07:00
NanoCode012
787880215b fix(deepspeed): deepspeed config not being set for z3 (#2754)
* fix(deepspeed): deepspeed config not being set for z3

* fix: comments
2025-06-03 14:27:09 -07:00
NanoCode012
4b1a29c694 feat(modal): update docker tag to use torch2.6 from torch2.5 (#2749) [skip ci] 2025-06-03 14:26:07 -07:00
NanoCode012
d7fa60662e feat: add chat_template kwargs (#2694) [skip ci] 2025-06-03 14:25:26 -07:00
Dan Saunders
1d91d905c9 remove deprecated wandb env var (#2751)
* remove deprecated wandb env var

* remove os.environ wandb setting; unused loggers

* remove os.environ wandb setting; unused loggers
2025-06-03 14:04:15 -07:00
mhenrhcsen
2bf61d8e25 fix abbriviatation spelling error 2025-06-03 21:30:40 +02:00
mhenrhcsen
68788e419e feat: add Group Relative Policy Optimization (GPRO) to RLHF documentation 2025-06-03 21:30:40 +02:00
github-actions[bot]
94219f6ee8 chore: update pre-commit hooks (#2745)
* chore: update pre-commit hooks

* trigger linter when pre commit hooks are updated

* fix type checks from upgraded pre-commit

---------

Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-02 15:54:29 -07:00
Wing Lian
ecc719f5c7 add support for base image with uv (#2691) 2025-06-02 12:48:55 -07:00
NanoCode012
d5d0dc5938 fix: suppress non-axolotl logs unless it's warning or higher (#2724)
* fix: increase log level for root loggers and axolotl's

* fix: BasePlugin using wrong logger

* fix: update logger to take name from module

* feat: change logger class to AxolotlLogger to filter non-axolotl infos or below

* fix: change behavior to not disable existing loggers

* fix: update logging to respect correct env

* chore: fix comment

* fix: suppress accelerate log to LOG_LEVEL if not set

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-05-31 12:13:43 +07:00
NanoCode012
5e86c35322 fix(log): remove duplicate merge_lora param (#2742) [skip ci] 2025-05-31 12:13:31 +07:00
NanoCode012
6778856804 Fix: RL base feature parity (#2133)
* feat: add num_proc and load from cache for rl mapping

* fix: refactor sft and rl trainer to set same base args

* feat: add report_to to set run name

* fix: consolidate handling of fp16, bf16, tf32 kwarg

* chore: consolidate eval_strat, loraplus, lr sched, max_length

* fix: deprecate old types

* fix: adding missing Any

* fix: max_steps incorrectly set

* fix: remove unnecessary datacollator kwarg insert and pop

* fix: update default max_steps

* fix: add missing weight_decay handling

* fix: ignore max_length for grpo

* feat: update CI on trainer_builder

* fix: comments

* improve handling of warmup/logging steps

* use transformers default for logging steps, not None

* fix: remove redundant override

* fix: lint

* feat: allow custom optim for rl methods

* fix: duplicate optim setting

* fix(test): set sequence_parallel_degree default in base cfg

* feat: add handling for seed and SP/ring-attn config

* chore: add back return typing from rebase

* fix(test): use RLType directly to skip needing to validate

* feat: split training builder into sub modules

* fix: remove deprecated clause

* chore: add missing config to doc

* fix: update quarto autodoc

* fix: import path for trainer builder and submodules

* fix: remove redundant configs from rebase mistake

* chore: simplify dynamo check

* fix: optimizer_cls_and_kwargs to be passed into trainer_kwargs

* fix: add missing rex from rebase

* fix: move pop optimizer_cls_and_kwargs

* fix: pop optimizer cls in rl too

* fix: leftover bug from rebase

* fix: update handling of trainer_cls in RL

* fix: address pr feedback

* feat: call hook_pre_create_trainer for rl

* chore: lint

* fix: return notimplemented for ppo

* feat: moved torch compile to base and refactor collator setting

* chore: remove unused importlib.util import

* fix: optimizer cls not being popped

* feat: move epoch setting to base

* fix: catch unhandled custom optimizer

* fix: remove duplicate lora plus setting

* chore: refactor if condition

* chore: refactor set_base_training_args into smaller modules

* fix: address TrainerBuilderBase class variables to instance var

* fix: add handling for beta3 and episilon2

* fix: change to pass dict via arg instead of updating dict

* chore: simplify if condition

* fix: force access to lr & weight decay in case not provided to early error

* fix: remove log sweep

* chore: refactor if condition

* fix: address renamed cfg

* fix: improve handling of cosine hyp

* fix: remove unused params

* chore: refactor

* chore: clarify doc safetensors

* fix: update import path to be unified following comments

* fix: duplicate kwargs passed

* feat: return separate trainer_kwargs

* chore: refactor

* chore: refactor based on comments

* chore: refactor based on comments

* fix: move gpustats callback to base

* chore: create trainer_cls_args first based on comments

* fix: ipo label smoothing passed incorrectly

* feat: add optimizer parity for RL methods with test

* feat: add parity for optimizer in RM/PRM and add test

* fix: remove redundant function override for orpo/cpo batch metrics

* fix: improve handling of dpo_label_smoothing and merge issue

* fix: test fixture returning wrong field

* fix: address avoid direct modify fixture

* chore: minor refactor

* Revert "chore: refactor"

This reverts commit 99c8859eb0.

* feat: rename trainer_builder to builders

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-30 11:21:47 +07:00
Wing Lian
ec4ebfd997 Add a few items to faq (#2734)
* Add a few items to faq

* formatting

* chore: lint
2025-05-28 16:20:19 -04:00
Dan Saunders
bde8b5b6bd fix dist state init before deepspeed setup (#2737) 2025-05-28 14:59:57 -04:00
Dan Saunders
2962a398b7 Lora kernels fix (#2732)
* fix lora kernel patching and improve test

* simplification
2025-05-28 10:03:43 -04:00
salman
65c5481120 Rank 0-only logging (#2608)
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-28 14:57:30 +01:00
salman
5fca214108 QAT (#2590)
QAT and quantization w/torchao
2025-05-28 12:35:47 +01:00
NanoCode012
20fda75917 feat(doc): add google analytics to docs (#2708) 2025-05-28 15:51:21 +07:00
NanoCode012
6b6370f4e3 feat(doc): add info on how to use dapo / dr grpo and misc doc fixes (#2673) [skip ci]
* feat(doc): add info on how to use dapo / dr grpo

* chore: add missing config to docs

* fix: missing comment

* fix: add missing scheduler from schema

* chore: refactor lr scheduler docs

* fix: remove log_sweep
2025-05-28 15:51:04 +07:00
mashdragon
add2025253 Fix Mistral chat template (mistral_v7_tekken) (#2710) [skip ci]
Per 4b8dd8aae7 (d2h-482763)
2025-05-28 15:50:47 +07:00
artem
a703560a10 add two checks to handle legacy format interleaved multimodal ds (#2721) [skip ci]
* add two checks to handle legacy format interleaved ds

* fix: add warning about multiple image using legacy format

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-28 15:49:43 +07:00
NOHHYEOB, BAE
4a80d309e8 Add chat templates for command-a and aya-23-8B models (#2731) [skip ci]
* Add chat templates for command-a and aya model

* Fix: isolate for-loop update and remove unintended changes
2025-05-28 15:49:16 +07:00
NanoCode012
e33f225434 feat(doc): note lora kernel incompat with RLHF (#2706) [skip ci]
* feat(doc): note lora kernel incompat with RLHF

* fix: add validation following comments

* chore: fix typo following suggestion
2025-05-28 15:48:40 +07:00
NanoCode012
3e6948be97 Fix(doc): clarify data loading for local datasets and splitting samples (#2726) [skip ci]
* fix(doc): remove incorrect json dataset loading method

* fix(doc): clarify splitting only happens in completion mode

* fix: update local file loading on config doc

* fix: typo
2025-05-28 15:48:22 +07:00
github-actions[bot]
4a8af60d34 chore: update pre-commit hooks (#2729)
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-05-27 11:45:31 -04:00
Dan Saunders
a0941a9271 no need to generate diff file (#2728) 2025-05-27 11:44:06 -04:00
Dan Saunders
5eb01f3df1 Fix quarto (#2717)
* missing modules

* fix quarto complaints
2025-05-23 21:16:51 -04:00
xzuyn
d27c35ac44 Liger GraniteMoE (#2715) 2025-05-23 18:40:43 -04:00
Dan Saunders
a535b68043 update quarto for model loading refactor (#2716)
* update quarto for model loading refactor

* fix desc
2025-05-23 16:28:31 -04:00
Dan Saunders
b5f1e53a0f models.py -> loaders/ module refactor (#2680)
* models.py -> loaders/ module refactor

* refactor ModelLoader class

* plugin manager changes

* circular import fix

* pytest

* pytest

* minor improvements

* fix

* minor changes

* fix test

* remove dead code

* coderabbit comments

* lint

* fix

* coderabbit suggestion I liked

* more coderabbit

* review comments, yak shaving

* lint

* updating in light of SP ctx manager changes

* review comment

* review comment 2
2025-05-23 15:51:11 -04:00
Dan Saunders
8cde256db2 Remove unused const (#2714)
* remove unused const

* accidentally commited benchmark plot
2025-05-23 12:27:38 -04:00
Dan Saunders
5f8f817200 SP context manager update (#2699)
* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* coderabbit commits, cleanup, remove dead code

* fix

* move ring attn patching to sp ctx manager

* lint

* lint

* test fix

* test fix
2025-05-22 11:18:32 -04:00
NanoCode012
aa0492c366 feat: do not find turn indices if turn is not trainable (#2696)
* feat: do not find turn indices if turn is not trainable

* fix: handle edge case where train on eos/eot is all

* fix: improve warning message
2025-05-22 19:19:59 +07:00
NanoCode012
798b5f5cfd fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]
* fix: plugin rl overwrite trainer_cls

* feat(test): add test to catch trainer_cls is not None
2025-05-22 19:19:12 +07:00
NanoCode012
1c83a1a020 feat(doc): clarify minimum pytorch and cuda to use blackwell (#2704) [skip ci] 2025-05-22 19:18:27 +07:00
Dan Saunders
6aa41740df SP dataloader patching + removing custom sampler / dataloader logic (#2686)
* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* small change

* coderabbit commits, cleanup, remove dead code

* quarto fix

* patch fix

* review comments

* moving monkeypatch up one level

* fix
2025-05-21 11:20:20 -04:00
Wing Lian
a27b909c5c GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap

* remove overrides needed to support bug

* fix grpo tests

* require more CPU for multigpu to help with torch compile for vllm
2025-05-16 15:47:03 -04:00
xzuyn
6cb07b9d12 Fix for setting adam_beta3 and adam_epsilon2 for CAME Optimizer (#2654) [skip ci]
* make setting `adam_beta3` and `adam_epsilon2` work correctly

* update config docs so users know args are specific to CAME optim

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:50 -04:00
C080
288653adb6 Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifa… (#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting

* cleanup and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:31 -04:00
NanoCode012
3a5b495a74 Fix: improve doc on merge/inference cli visibility (#2674)
* feat: improve visibility for merge doc

* feat: add tip on reuse config between modes
2025-05-16 13:07:40 -04:00
xzuyn
f661858fc4 Print dataset name (#2668) [skip ci] 2025-05-16 13:06:58 -04:00
Eric Meier
c837c4a424 Add missing init file to liger plugin (#2670) [skip ci] 2025-05-16 13:06:46 -04:00
michelyang
c9797de6bb Add num_proc to fix data set slow processing issue (#2681) [skip ci] 2025-05-16 13:06:20 -04:00
Wing Lian
8f8a7afb05 Add ci and images for CUDA 12.8 for B200s (#2683) [skip ci]
* Add ci and images for CUDA 12.8 for B200s

* add comments explaining CI [skip e2e]
2025-05-16 13:06:08 -04:00
NanoCode012
86472715da fix: remove doc string imports in monkeypatches (#2671) [skip ci] 2025-05-16 13:05:55 -04:00
Wing Lian
c0a0c7534c Activation checkpointing with offloading to disk with prefetch (#2663)
* offload activations to disk instead of CPU RAM

* add prefetch

* Disco :dance:

* include offload_disk in e2e test for AC

* document and make sure to cleanup

* fix annotation to match docs

* fix docs build

* address PR feedback
2025-05-13 16:39:39 -04:00
251 changed files with 8342 additions and 5702 deletions

View File

@@ -17,7 +17,7 @@ jobs:
build-base:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
matrix:
@@ -28,42 +28,50 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: next
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: next
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-next"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -85,7 +93,59 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
file: ./docker/${{ matrix.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDNN_VERSION=${{ matrix.cudnn_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
build-base-uv:
if: github.repository_owner == 'axolotl-ai-cloud'
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -9,6 +9,7 @@ on:
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"
- ".pre-commit-config.yaml"
workflow_dispatch:
jobs:

View File

@@ -31,6 +31,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -94,6 +99,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -59,7 +59,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -25,7 +25,6 @@ jobs:
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
@@ -39,11 +38,3 @@ jobs:
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -44,98 +44,6 @@ jobs:
env:
SKIP: no-commit-to-branch
# preload-cache:
# name: Preload HF cache
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python_version: ["3.11"]
# pytorch_version: ["2.6.0"]
# timeout-minutes: 20
#
# env:
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
#
# steps:
# - name: Check out repository code
# uses: actions/checkout@v4
#
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
#
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p /home/runner/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
#
# - name: Setup Python
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python_version }}
# cache: 'pip' # caching pip dependencies
#
# - name: upgrade pip
# run: |
# pip3 install --upgrade pip
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
#
# - name: Install PyTorch
# run: |
# pip3 install torch==${{ matrix.pytorch_version }}
#
# - name: Install dependencies
# run: |
# pip3 show torch
# pip3 install --no-build-isolation -U -e .
# python scripts/unsloth_install.py | sh
# python scripts/cutcrossentropy_install.py | sh
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
#
# - name: Make sure PyTorch version wasn't clobbered
# run: |
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
#
# - name: Ensure axolotl CLI was installed
# run: |
# axolotl --help
#
# - name: Pre-Download dataset fixture
# run: |
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
#
# - name: Run tests
# run: |
# pytest -v tests/conftest.py
#
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v5
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# files: ./coverage.xml
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
# fail_ci_if_error: false
#
# - name: cleanup pip cache
# run: |
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
#
# - name: Save HF cache
# id: hf-cache
# uses: actions/cache/save@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
@@ -151,15 +59,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
@@ -222,7 +121,6 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
@@ -234,15 +132,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
@@ -295,6 +184,7 @@ jobs:
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
@@ -311,6 +201,13 @@ jobs:
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -321,7 +218,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -332,6 +229,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -341,6 +239,8 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
@@ -365,6 +265,12 @@ jobs:
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -375,7 +281,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -386,6 +292,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -415,7 +322,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -19,15 +19,15 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.1.2
rev: 7.2.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.6
rev: v3.3.7
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
rev: v1.16.0
hooks:
- id: mypy
additional_dependencies:

View File

@@ -242,16 +242,12 @@
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler: # 'one_cycle' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134

View File

@@ -51,7 +51,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.4.1
- PyTorch ≥2.5.1
### Installation

View File

@@ -17,7 +17,9 @@ quartodoc:
- convert
- prompt_tokenizers
- logging_config
- core.trainer_builder
- core.builders.base
- core.builders.causal
- core.builders.rl
- core.training_args
- core.chat.messages
- core.chat.format.chatml
@@ -43,6 +45,7 @@ quartodoc:
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- title: Trainers
desc: Training implementations
contents:
@@ -54,13 +57,21 @@ quartodoc:
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
- core.trainers.utils
- title: Model Loading
desc: Functionality for loading and patching models, tokenizers, etc.
contents:
- loaders.model
- loaders.tokenizer
- loaders.processor
- loaders.adapter
- loaders.patch_manager
- loaders.constants
- title: Mixins
desc: Mixin classes for augmenting trainers
contents:
- core.trainers.mixins.optimizer
- core.trainers.mixins.rng_state_loader
- core.trainers.mixins.scheduler
- core.trainers.mixins.sequence_parallel
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:
@@ -118,17 +129,16 @@ quartodoc:
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
- monkeypatch.gradient_checkpointing.offload_disk
- title: Utils
desc: Utility functions
contents:
- utils.models
- utils.tokenization
- utils.chat_templates
- utils.lora
- utils.lora_embeddings
- utils.model_shard_quant
- utils.bench
- utils.freeze
@@ -139,7 +149,7 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- utils.quantization
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
@@ -189,12 +199,14 @@ quartodoc:
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
- utils.callbacks.qat
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
favicon: favicon.jpg
google-analytics: "G-9KYCVJBNMQ"
navbar:
logo: image/axolotl_logo_digital_white.svg
title: false
@@ -247,6 +259,8 @@ website:
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- section: "Core Concepts"
contents:

52
cicd/Dockerfile-uv.jinja Normal file
View File

@@ -0,0 +1,52 @@
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
@@ -55,7 +55,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
GPU_CONFIG = f"H100:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
cpu=8.0,
cpu=16.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)

View File

@@ -8,8 +8,9 @@ import tempfile
import jinja2
import modal
import modal.experimental
from jinja2 import select_autoescape
from modal import App, Image
from modal import App
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -17,14 +18,15 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
@@ -38,11 +40,11 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
cicd_image = modal.experimental.raw_dockerfile_image(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
# context_mount=None,
force_build=True,
gpu="A10G",
# gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
@@ -55,7 +57,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):

36
docker/Dockerfile-uv-base Normal file
View File

@@ -0,0 +1,36 @@
ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
ENV UV_TORCH_BACKEND="cu${CUDA}"
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& git lfs install --skip-repo \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:${PATH}"
RUN uv python install ${PYTHON_VERSION}
WORKDIR /workspace
RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic

View File

@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
### quantize
Quantizes a model using the quantization configuration specified in your YAML file.
```bash
axolotl quantize config.yml
```
See [Quantization](./quantize.qmd) for more details.
## Legacy CLI Usage

View File

@@ -65,6 +65,20 @@ bnb_config_kwargs:
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# quantization aware training
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
# post-training quantization
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -98,8 +112,10 @@ plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
# HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -221,7 +237,7 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
Deduplicates datasets and test_datasets with identical entries.
# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
@@ -270,10 +286,25 @@ trl:
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
temperature: # Optional[float]. Sampling temperature for the GRPO policy.
top_p: # Optional[float]. Top-p sampling probability for the generation policy.
top_k: # Optional[int]. Top-k sampling for the generation policy.
min_p: # Optional[float]. Minimum probability for the generation policy.
repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
# reward modelling: `True` or `False`
@@ -483,6 +514,7 @@ output_dir: ./completed-model
# setting to `auto` will enable torch compile when torch>=2.5.1
torch_compile: # Optional[Union[Literal["auto"], bool]]
torch_compile_backend: # Optional[str]
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
# Training hyperparameters
@@ -529,7 +561,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
# Save model as safetensors (require safetensors package). Default True
save_safetensors:
# Whether to mask out or include the human's prompt from the training labels
@@ -539,7 +571,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
@@ -551,7 +583,24 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
# Valid values are driven by the Transformers SchedulerType class, see:
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
# Valid values include
# - 'linear'
# - 'cosine' (default)
# - 'cosine_with_restarts'
# - 'polynomial'
# - 'constant'
# - 'constant_with_warmup'
# - 'inverse_sqrt'
# - 'reduce_lr_on_plateau'
# - 'cosine_with_min_lr'
# - 'warmup_stable_decay'
# Additional schedulers include:
# - 'one_cycle'
# - 'rex'
lr_scheduler:
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -569,7 +618,7 @@ lr_div_factor: # Learning rate div factor
#
# Valid values for 'optimizer' include:
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_fused (default)
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
@@ -633,7 +682,9 @@ weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

View File

@@ -36,10 +36,6 @@ It is typically recommended to save your dataset as `.jsonl` due to its flexibil
Axolotl supports loading from a Hugging Face hub repo or from local files.
::: {.callout-important}
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
:::
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
@@ -77,18 +73,21 @@ datasets:
type: completion
```
From local files (either example works):
From local files:
```yaml
datasets:
- path: A.jsonl
type: completion
- path: json
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
- path: B.jsonl
type: completion
```
::: {.callout-important}
For `completion` only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
:::
### Pre-training dataset configuration tips
#### Setting max_steps

View File

@@ -54,7 +54,7 @@ datasets:
#### Files
Usually, to load a JSON file, you would do something like this:
To load a JSON file, you would do something like this:
```python
from datasets import load_dataset
@@ -66,20 +66,12 @@ Which translates to the following config:
```yaml
datasets:
- path: json
data_files: /path/to/your/file.jsonl
```
However, to make things easier, we have added a few shortcuts for loading local dataset files.
You can just point the `path` to the file or directory along with the `ds_type` to load the dataset. The below example shows for a JSON file:
```yaml
datasets:
- path: /path/to/your/file.jsonl
- path: data.json
ds_type: json
```
In the example above, it can be seen that we can just point the `path` to the file or directory along with the `ds_type` to load the dataset.
This works for CSV, JSON, Parquet, and Arrow files.
::: {.callout-tip}

View File

@@ -8,6 +8,10 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with Pytorch 2.7.0 and CUDA 12.8.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
@@ -32,7 +36,6 @@ Tags examples:
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
@@ -73,12 +76,10 @@ Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
- `0.9.2`
## Cloud

View File

@@ -110,3 +110,17 @@ description: Frequently asked questions
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
**Q: `Data processing error: CAS service error`**
> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**
> A: Depending on the version of torch, you may need to include this in your YAML:
> ```yaml
> flex_attn_compile_kwargs:
> dynamic: false
> mode: max-autotune-no-cudagraphs
> ```

View File

@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
format):
```json
@@ -120,6 +120,12 @@ axolotl train my_training.yml
## Common Tasks {#sec-common-tasks}
::: {.callout-tip}
The same yaml file is used for training, inference, and merging.
:::
### Testing Your Model {#sec-testing}
After training, test your model:
@@ -128,6 +134,16 @@ After training, test your model:
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
More details can be found in [Inference](inference.qmd).
### Using a UI {#sec-ui}
Launch a Gradio interface:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
axolotl preprocess my_training.yml
```
### Using a UI {#sec-ui}
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
Launch a Gradio interface:
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
### Merging LoRA weights {#sec-merging-lora}
To merge the LoRA weights back into the base model, run:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
```
The merged model will be saved in the `{output_dir}/merged` directory.
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
- PyTorch ≥2.5.1
## Installation Methods {#sec-installation-methods}
@@ -25,6 +25,10 @@ Please make sure to have Pytorch installed before installing Axolotl in your loc
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
@@ -37,6 +41,40 @@ installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
Install uv if not already installed
```{.bash}
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
then create the venv and activate
```{.bash}
export UV_TORCH_BACKEND=cu126
uv venv --no-project --relocatable
source .venv/bin/activate
```
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
```
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:
@@ -72,6 +110,10 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}

View File

@@ -84,6 +84,10 @@ lora_qkv_kernel: true
lora_o_kernel: true
```
::: {.callout-note}
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}

View File

@@ -43,7 +43,7 @@ datasets:
# leave the vision model and vision tower frozen
# load_in_8bit: true
adapter: lora
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# (optional) if you want to resize images to a set size
image_size: 512

32
docs/qat.qmd Normal file
View File

@@ -0,0 +1,32 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of models which are quantized
by applying "fake" quantizations to the model's weights (and optionally, activations) during training. This fake
quantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually
quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
support for QAT and post-training quantization (PTQ) in axolotl.
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.
## Configuring QAT in Axolotl
To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize` command](./quantize.md) to do this.

53
docs/quantize.qmd Normal file
View File

@@ -0,0 +1,53 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).
::: {.callout-note}
We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
:::
## Configuring Quantization in Axolotl
Quantization is configured using the `quantization` key in your configuration file.
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
output_dir: # The path to the output directory.
```
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
you used to train the model:
```yaml
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
```bash
axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.

View File

@@ -16,7 +16,8 @@ feedback. Various methods include, but not limited to:
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
## RLHF using Axolotl
@@ -582,7 +583,20 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
```yaml
trl:
loss_type: dr_grpo
# Normalizes loss based on max completion length (default: 256)
max_completion_length:
```
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### SimPO

View File

@@ -41,7 +41,7 @@ When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
## Requirements
@@ -67,9 +67,11 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
...
```

View File

@@ -28,7 +28,7 @@ pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -30,7 +30,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -29,7 +29,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -0,0 +1,79 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./outputs/qat_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
@@ -38,6 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -27,7 +27,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
chat_template: qwen_25
rl: dpo
datasets:

View File

@@ -0,0 +1,78 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out/
sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 256
fake_quant_after_n_steps: 1000
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:

View File

@@ -6,21 +6,21 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.9
liger-kernel==0.5.10
# END section
packaging==23.2
huggingface_hub==0.31.0
huggingface_hub==0.32.2
peft==0.15.2
transformers==4.51.3
transformers==4.52.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0
hqq==0.2.5
accelerate==1.7.0
datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.1
hf_xet==1.1.2
mistral-common[hf-hub]==1.6.0
optimum==1.16.2
hf_transfer
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.9.0
torchao==0.10.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6

View File

@@ -9,6 +9,8 @@ except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
USE_UV = "--uv" in sys.argv[1:]
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
@@ -23,7 +25,9 @@ if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
)

View File

@@ -11,7 +11,7 @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
```
cd /workspace

View File

@@ -1,11 +1,15 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
@@ -31,6 +35,7 @@ elif v < V("2.6.0"):
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
)

View File

@@ -118,7 +118,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed==0.17.0",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -28,7 +28,6 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
@@ -89,6 +88,26 @@ class VllmServeCliArgs:
},
)
enable_reasoning: Optional[bool] = field(
default=None,
)
reasoning_parser: Optional[str] = field(
default=None,
)
@dataclass
class QuantizeCliArgs:
"""Dataclass with CLI arguments for `axolotl quantize` command."""
base_model: Optional[str] = field(default=None)
weight_dtype: Optional[str] = field(default=None)
activation_dtype: Optional[str] = field(default=None)
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
@dataclass
class EvaluateCliArgs:

View File

@@ -1,6 +1,5 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_accelerate_default_config() -> None:

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"

View File

@@ -1,7 +1,6 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
@@ -22,11 +21,12 @@ from axolotl.utils.config import (
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
LOG.info("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
LOG.info(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
LOG.info("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
LOG.info("Invalid input. Please enter a number.")
return chosen_file

View File

@@ -1,6 +1,5 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

View File

@@ -1,7 +1,6 @@
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def get_multi_line_input() -> str:

View File

@@ -2,7 +2,6 @@
# pylint: disable=redefined-outer-name
import logging
import os
import subprocess # nosec B404
import tempfile
@@ -17,6 +16,7 @@ import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
@@ -30,8 +30,11 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -176,7 +179,7 @@ def train(
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
@@ -333,6 +336,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
from axolotl.cli.quantize import do_quantize
do_quantize(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))

View File

@@ -1,20 +1,18 @@
"""CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_merge_lora(*, cfg: DictDefault) -> None:
@@ -68,12 +66,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
# pylint: disable=duplicate-code
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,

View File

@@ -1,7 +1,6 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json
import logging
import os
import shutil
from pathlib import Path
@@ -11,7 +10,6 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
import transformers
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -24,11 +22,11 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
@@ -197,11 +195,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -1,6 +1,5 @@
"""CLI to run preprocessing of a dataset."""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

View File

@@ -0,0 +1,90 @@
"""
CLI to post-training quantize a model using torchao
"""
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = get_logger(__name__)
def do_quantize(
config: Union[Path, str],
cli_args: dict,
):
"""
Quantizes a model's model's weights
Args:
config (Union[Path, str]): The path to the config file
cli_args (dict): Additional command-line arguments
"""
print_axolotl_text_art()
cfg = load_cfg(config)
if cfg.qat and cfg.quantization:
raise ValueError(
"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
)
if cfg.qat:
quantize_cfg = cfg.qat
elif cfg.quantization:
quantize_cfg = cfg.quantization
else:
raise ValueError(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
quantize_embedding = (
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
f"\tweight_dtype: {weight_dtype}\n"
f"\tactivation_dtype: {activation_dtype}\n"
f"\tgroup_size: {group_size}\n"
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
from typing import Union
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""

View File

@@ -4,7 +4,6 @@ import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -20,10 +19,12 @@ from transformers import (
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):
@@ -304,8 +305,8 @@ def load_model_and_tokenizer(
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Helper function for loading a model, tokenizer, and processor specified in the
given `axolotl` config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -318,7 +319,8 @@ def load_model_and_tokenizer(
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
processor = None
if cfg.is_multimodal:

View File

@@ -2,14 +2,27 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@dataclass
class AxolotlScriptArguments(ScriptArguments):
"""
Additional arguments for the VLLM server
"""
reasoning_parser: str = field(default="", kw_only=True)
enable_reasoning: bool | None = field(default=None, kw_only=True)
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
@@ -24,6 +37,7 @@ def do_vllm_serve(
Returns:
process_id: the process id of the started VLLM server
"""
patch_vllm_worker()
cfg = load_cfg(config)
model = cfg.base_model
@@ -43,9 +57,16 @@ def do_vllm_serve(
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
reasoning_parser = (
cli_args.get("reasoning_parser") or cfg.vllm.reasoning_parser or ""
)
enable_reasoning = (
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
vllm_script_args = ScriptArguments(
model,
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
@@ -53,5 +74,67 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker

View File

@@ -1,6 +1,5 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
@@ -10,14 +9,15 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
@dataclass

View File

@@ -0,0 +1,6 @@
"""Trainer builder classes"""
from .causal import HFCausalTrainerBuilder
from .rl import HFRLTrainerBuilder
__all__ = ["HFCausalTrainerBuilder", "HFRLTrainerBuilder"]

View File

@@ -0,0 +1,503 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for trainer builder"""
import abc
import importlib
import logging
import sys
from abc import abstractmethod
from contextlib import suppress
from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
self._train_dataset = None
self._eval_dataset = None
self._model_ref = None
self._peft_config = None
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@property
def peft_config(self):
return self._peft_config
@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config
@abstractmethod
def build(self, total_num_steps):
pass
def get_callbacks(self) -> list[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
callbacks.append(GPUStatsCallback(cfg=self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_ratio = self.cfg.warmup_ratio
elif total_num_steps:
warmup_steps = min(int(0.03 * total_num_steps), 100)
else:
warmup_ratio = 0.03
if warmup_steps == 1:
warmup_steps = 2
if self.cfg.logging_steps is not None:
training_args_kwargs["logging_steps"] = self.cfg.logging_steps
else:
training_args_kwargs["logging_steps"] = (
500 # transformers defaults to 500
if not total_num_steps
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwargs: dict):
def _configure_custom_optimizer(
training_args_kwargs: dict, trainer_kwargs: dict
):
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_args_kwargs["learning_rate"],
"weight_decay": training_args_kwargs["weight_decay"],
}
# Adam-specific kwargs
adam_kwargs: dict = {}
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
"adam_beta2"
):
adam_kwargs["betas"] = (
training_args_kwargs.get("adam_beta1"),
training_args_kwargs.get("adam_beta2"),
)
if training_args_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
beta3 = training_args_kwargs.get("adam_beta3", 0.9999)
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
# Note: This is not used in training_args_kwargs, but in trainer_kwargs
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
_configure_custom_optimizer(training_args_kwargs, trainer_kwargs)
else:
# Use transformers' optimizer
training_args_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_args_kwargs["optim_args"] = optim_args
if (
self.cfg.optimizer == "adamw_anyprecision"
and Path(self.cfg.torchdistx_path).exists()
):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
def _configure_hub_parameters(self, training_args_kwargs: dict):
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
# eval_strategy and eval_steps
if not self.eval_dataset or self.cfg.val_set_size == 0:
# do not eval if no eval_dataset or val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
def _configure_reporting(self, training_args_kwargs: dict):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_args_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_args_kwargs["run_name"] = None
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
def _set_base_training_args(
self, total_num_steps
) -> tuple[dict[str, Any], dict[str, Any]]:
training_args_kwargs: dict[str, Any] = {}
trainer_kwargs: dict[str, Any] = {}
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
self._configure_precision_settings(training_args_kwargs)
self._configure_save_and_eval_strategy(training_args_kwargs)
self._configure_gradient_checkpointing(training_args_kwargs)
# set arg into trainer_args_kwargs with same name if value not None
for arg in [
# optim/scheduler
"adam_beta1",
"adam_beta2",
"adam_beta3",
"adam_epsilon",
"adam_epsilon2",
"cosine_min_lr_ratio",
"cosine_constant_lr_ratio",
"optim_target_modules",
# trainer
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"embedding_lr",
"embedding_lr_scale",
"lr_groups",
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs)
return training_args_kwargs, trainer_kwargs

View File

@@ -0,0 +1,489 @@
"""Builder for causal trainers"""
import inspect
import math
import os
from pathlib import Path
from typing import Type, Union
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlMambaTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(EvalFirstStepCallback())
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not self.cfg.flash_attention
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs["lisa_step_interval"] = (
self.cfg.lisa_step_interval
)
training_arguments_kwargs["lisa_layers_attribute"] = (
self.cfg.lisa_layers_attribute
)
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs["neftune_noise_alpha"] = (
self.cfg.neftune_noise_alpha
)
if self.cfg.accelerator_config:
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
multiple = 64
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters:
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
def build_collator(
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
use_batch_sampler_collator = False
if is_eval is False and training_args.sample_packing:
use_batch_sampler_collator = True
if is_eval and training_args.eval_sample_packing:
use_batch_sampler_collator = True
collator: Type[
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
]
]
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama
if (
self.cfg.flex_attention
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
)
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
return collator(
*collator_args,
**kwargs,
)

View File

@@ -0,0 +1,246 @@
"""Builder for RLHF trainers"""
import inspect
from pathlib import Path
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
LOG = get_logger(__name__)
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def _get_trainer_cls(self, trainer_kwargs: dict):
"""
Returns trainer_cls and trainer_cls_args
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore
if trainer_cls is not None:
return trainer_cls, trainer_cls_args
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args.append(self.model_ref)
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):
"""
Returns training_args and trainer_kwargs
"""
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
# only rlhf
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
# Not compatible with IPO
if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args, trainer_kwargs
def build(self, total_num_steps):
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# TODO: build PPOConfig
raise NotImplementedError("PPO trainer builder is not implemented yet.")

View File

@@ -156,7 +156,6 @@ class Messages(BaseModel):
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:

File diff suppressed because it is too large Load Diff

View File

@@ -4,11 +4,10 @@
from __future__ import annotations
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Literal
from functools import partial, wraps
from typing import Callable, Literal, Optional
import datasets
import torch
@@ -29,20 +28,18 @@ from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -68,10 +65,6 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -120,10 +113,12 @@ class AxolotlTrainer(
drop_last=True,
)
def _get_train_sampler(self) -> Sampler | None:
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Helper method to get the sampler for training. Handles cases for sample packing
and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
@@ -132,9 +127,7 @@ class AxolotlTrainer(
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
if self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
@@ -146,31 +139,26 @@ class AxolotlTrainer(
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
dataset=train_dataset,
)
return base_sampler
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Helper method to get the sampler for evaluation. Handles sample packing case.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
if use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
@@ -184,149 +172,91 @@ class AxolotlTrainer(
return base_sampler
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
def _get_dataloader(
self,
dataset: Dataset,
description: str,
batch_size: int,
sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
is_training: bool = False,
dataloader_key: Optional[str] = None,
) -> DataLoader:
"""Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
params = {
data_collator = self.data_collator if is_training else self.eval_data_collator
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if isinstance(dataset, datasets.Dataset):
if is_training:
if not self.args.sample_packing or self.args.pretraining:
dataset = self._remove_unused_columns(
dataset, description="training"
)
elif (
not is_training
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
batch_size = (
batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
else:
dataset = self._remove_unused_columns(dataset, description=description)
else:
data_collator = self._get_collator_with_removed_columns(
self.data_collator, description=description
)
dataloader_params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
del dataloader_params["drop_last"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if is_training:
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
dataloader = DataLoader(dataset, **dataloader_params)
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
# Accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version for eval dataloaders.
# fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
else:
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
# fmt: on
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
return dataloader
return super().get_eval_dataloader(eval_dataset)
return self.accelerator.prepare(dataloader)
def _get_bench_sampler(
self, bench_dataset: Dataset

View File

@@ -1,92 +1,41 @@
"""
DPO trainer for axolotl
"""
"""DPO trainer for axolotl"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
):
"""Extend the base DPOTrainer for axolotl helpers."""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
# pylint: disable=duplicate-code
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
# pylint: disable=duplicate-code
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
Overwrite the `push_to_hub` method in order to force-add the tags when pushing
the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub`
for more details.
"""
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
@@ -95,64 +44,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -192,69 +83,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
gc.collect()
torch.cuda.empty_cache()
return loss
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
# Base evaluation
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
return initial_output

View File

@@ -2,7 +2,6 @@
import importlib
import inspect
import logging
from typing import Any
from trl.trainer.grpo_trainer import RewardFunc
@@ -13,9 +12,10 @@ from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class GRPOStrategy:
@@ -69,6 +69,9 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -106,7 +109,9 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
def set_trainer_args(
cls, cfg: DictDefault
) -> list[Any]: # pylint: disable=unused-argument
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
@@ -123,6 +128,7 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
return trainer_kwargs
@classmethod
@@ -132,7 +138,7 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc"]
return ["dataset_num_proc", "max_length"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
@@ -167,4 +173,4 @@ class GRPOStrategy:
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func
return reward_func_fqn

View File

@@ -12,3 +12,5 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None

View File

@@ -3,7 +3,7 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from contextlib import nullcontext
from functools import partial
from typing import Any
import datasets
@@ -14,7 +14,7 @@ from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
is_peft_available,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -30,15 +30,13 @@ from transformers import (
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.extras.profiling import profiling_context
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -46,67 +44,56 @@ from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
@@ -130,6 +117,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: "PeftConfig | None" = None,
optimizer_cls_and_kwargs: tuple[type, dict] | None = None,
):
# First call the superclass constructor with all arguments
super().__init__(
@@ -143,6 +131,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
)
# Get number of SP groups (number of processes divided by SP degree)
@@ -184,6 +173,13 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
f"the valid values for the number of generations are: {possible_values}."
)
self.sp_group = None
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.local_rank = 0
self.local_world_size = 1
def train(self, *args, **kwargs):
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
@@ -191,6 +187,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
return super().train(*args, **kwargs)
def _get_train_sampler(self) -> Sampler:
effective_batch_size = (
self.args.per_device_train_batch_size

View File

@@ -6,4 +6,3 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,18 +1,17 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class OptimizerMixin(Trainer):
@@ -199,3 +198,20 @@ class OptimizerMixin(Trainer):
)
return self.optimizer
class OptimizerInitMixin:
"""
Mixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not
accept optimizer_cls_and_kwargs as kwarg in constructor.
"""
def __init__(self, *args, **kwargs):
optimizer_cls_and_kwargs = kwargs.pop("optimizer_cls_and_kwargs", None)
super().__init__(*args, **kwargs)
if (
optimizer_cls_and_kwargs
and self.optimizer_cls_and_kwargs is None
and self.optimizer is None
):
self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs

View File

@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""
import logging
import os
import random
@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class RngLoaderMixin(Trainer):

View File

@@ -1,12 +1,11 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class SchedulerMixin(Trainer):
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
min_lr=0 if not use_cosine_min_lr else (
self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
LOG.warning(
"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler # type: ignore

View File

@@ -1,87 +0,0 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
)
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
specifically for creating appropriate data samplers.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)

View File

@@ -1,7 +1,5 @@
"""Module for TRL PPO trainer"""
from typing import Literal, Union
import torch
from tqdm import tqdm
from trl import (
@@ -14,6 +12,7 @@ from trl import (
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -75,87 +74,19 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
chosen_rewards
).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
rejected_rewards
).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
reward_accuracies
).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
)
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
policy_rejected_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
)
metrics[f"{prefix}log_odds_ratio"] = (
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
)
metrics[f"{prefix}log_odds_chosen"] = (
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
)
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -163,89 +94,19 @@ class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = (
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
)
metrics[f"{prefix}rewards/rejected"] = (
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
)
metrics[f"{prefix}rewards/accuracies"] = (
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
)
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
.mean()
.item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
)
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
):
"""
Extend the base RewardTrainer for axolotl helpers
"""
@@ -253,7 +114,9 @@ class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

View File

@@ -9,8 +9,6 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass
class AxolotlTrainingMixins:
@@ -166,12 +164,6 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
@@ -216,14 +208,16 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
ring_attn_func: Optional[RingAttnFunc] = field(
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)

View File

@@ -1,12 +1,13 @@
"""Module containing Dataset functionality"""
import logging
import os
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
@@ -15,7 +16,7 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset):
@@ -63,6 +64,10 @@ class TokenizedPromptDataset(Dataset):
desc="Strategy Filtering Rows",
)
import ipdb
ipdb.set_trace()
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -10,71 +10,83 @@
# License for the specific language governing permissions and limitations under
# the License.
"""
Base class for all plugins.
"""Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
from __future__ import annotations
import collections
import importlib
import logging
from typing import OrderedDict
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
import torch
from peft import PeftModel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__, use_environ=True)
if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
class BasePlugin:
"""
Base class for all plugins. Defines the interface for plugin methods.
"""Base class for all plugins. Defines the interface for plugin methods.
Attributes:
None
A plugin is a reusable, modular, and self-contained piece of code that extends
the functionality of Axolotl. Plugins can be used to integrate third-party models,
modify the training process, or add new features.
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
To create a new plugin, you need to inherit from the BasePlugin class and
implement the required methods.
Note:
Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is loaded, but
before LoRA adapters are applied.
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded,
inclusive of any adapters.
- post_trainer_create(cfg, trainer): Performs actions after the trainer is
created.
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
returns a learning rate scheduler.
- add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
training.
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
training.
"""
def __init__(self):
"""
Initializes the BasePlugin.
"""
"""Initializes the BasePlugin."""
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
Args:
cfg: The configuration for the plugin.
"""
def get_input_args(self) -> str | None:
"""
Returns a pydantic model for the plugin's input arguments.
"""
"""Returns a pydantic model for the plugin's input arguments."""
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
"""Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
@@ -84,181 +96,164 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions before the model is loaded.
Args:
cfg (dict): The configuration for the plugin.
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied.
Args:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
Args:
cfg: The global axolotl configuration.
Returns:
None
The first non-`None` trainer class returned by a plugin.
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is built/loaded, but before any adapters are applied.
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
cfg: The configuration for the plugin.
trainer: The trainer object for training.
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
cfg: The configuration for the plugin.
trainer: The trainer object for training.
Returns:
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
The created optimizer.
"""
# pylint: disable=unused-argument
def create_lr_scheduler(
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
self,
cfg: DictDefault,
trainer: Trainer,
optimizer: Optimizer,
num_training_steps: int,
) -> LRScheduler | None:
"""Creates and returns a learning rate scheduler.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
cfg: The configuration for the plugin.
trainer: The trainer object for training.
optimizer: The optimizer for training.
num_training_steps: Total number of training steps
Returns:
object (LRScheduler): The created learning rate scheduler.
The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
"""Set up callbacks before creating the trainer.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
cfg: The configuration for the plugin.
model: The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
A list of callback functions to be added to the `TrainingArgs`.
"""
return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
"""Adds callbacks to the trainer after creating the trainer. This is useful for
callbacks that require access to the model or trainer.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
cfg: The configuration for the plugin.
trainer: The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
A list of callback functions to be added
"""
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
Args:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
cfg: The axolotl configuration.
model: The loaded model.
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions after training is complete and the model is unloaded.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
cfg: The configuration for the plugin.
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
Loads a plugin based on the given plugin name.
"""Loads a plugin based on the given plugin name.
The plugin name should be in the format "module_name.class_name".
This function splits the plugin name into module and class, imports the module,
retrieves the class from the module, and creates an instance of the class.
The plugin name should be in the format "module_name.class_name". This function
splits the plugin name into module and class, imports the module, retrieves the
class from the module, and creates an instance of the class.
Parameters:
plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name".
Args:
plugin_name: The name of the plugin to be loaded. The name should be in the
format "module_name.class_name".
Returns:
BasePlugin: An instance of the loaded plugin.
An instance of the loaded plugin.
Raises:
ImportError: If the plugin module cannot be imported.
ImportError: If the plugin module cannot be imported.
"""
# split the plugin name into module and class
module_name, class_name = plugin_name.rsplit(".", 1)
@@ -284,28 +279,26 @@ def load_plugin(plugin_name: str) -> BasePlugin:
class PluginManager:
"""
The PluginManager class is responsible for loading and managing plugins.
It should be a singleton so it can be accessed from anywhere in the codebase.
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
Attributes:
plugins (List[BasePlugin]): A list of loaded plugins.
plugins: A list of loaded plugins.
Methods:
get_instance(): Static method to get the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
Note:
Key methods include:
- get_instance(): Static method to get the singleton instance of `PluginManager`.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
_cfg = None
_instance: PluginManager | None = None
_cfg: DictDefault | None = None
def __new__(cls):
"""
Creates a new instance of PluginManager if it doesn't exist yet.
"""
"""Creates a new instance of PluginManager if it doesn't exist yet."""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
@@ -315,9 +308,8 @@ class PluginManager:
@staticmethod
def get_instance() -> "PluginManager":
"""
Returns the singleton instance of PluginManager.
If the instance doesn't exist, it creates a new one.
"""Returns the singleton instance of PluginManager. If the instance doesn't
exist, it creates a new one.
"""
if PluginManager._instance is None:
PluginManager()
@@ -332,32 +324,27 @@ class PluginManager:
self._cfg = cfg
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
"""Registers a new plugin by its name.
Parameters:
plugin_name (str): The name of the plugin to be registered.
Returns:
None
Args:
plugin_name: The name of the plugin to be registered.
Raises:
ImportError: If the plugin module cannot be imported.
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
LOG.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}")
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
LOG.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self):
"""
Returns a list of Pydantic classes for all registered plugins' input arguments.'
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
Returns:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins.values():
@@ -366,16 +353,17 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def load_datasets(self, cfg, preprocess: bool = False):
"""
Calls the load_datasets method of each registered plugin.
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
"""Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess : Whether this is preprocess step of the datasets.
preprocess: Whether this is preprocess step of the datasets.
Returns:
dataset_meta: The dataset metadata loaded from all registered plugins.
The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
@@ -387,83 +375,66 @@ class PluginManager:
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
def pre_model_load(self, cfg: DictDefault):
"""Calls the pre_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
None
Args:
cfg: The configuration for the plugins.
"""
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_build(self, cfg, model):
"""
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
but before any adapters have been applied.
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Calls the `post_model_build` method of all registered plugins after the
model has been built / loaded, but before any adapters have been applied.
Args:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_build(cfg, model)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins after the model has been loaded
inclusive of any adapters
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Calls the `pre_lora_load` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
"""
Calls the pre_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
"""
Calls the post_lora_load method of all registered plugins.
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the `post_lora_load` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the `post_model_load` method of all registered plugins after the model
has been loaded inclusive of any adapters.
Parameters:
cfg (dict): The configuration for the plugins.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Calls the `get_trainer_cls` method of all registered plugins and returns the
first non-`None` trainer class.
Args:
cfg: The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
The first non-`None` trainer class returned by a plugin.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
@@ -471,29 +442,25 @@ class PluginManager:
return trainer_cls
return None
def post_trainer_create(self, cfg, trainer):
"""
Calls the post_trainer_create method of all registered plugins.
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
None
Args:
cfg: The configuration for the plugins.
trainer: The trainer object for training.
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
def create_optimizer(self, trainer: Trainer) -> Optimizer | None:
"""Calls the `create_optimizer` method of all registered plugins and returns
the first non-`None` optimizer.
Parameters:
trainer (object): The trainer object for training.
Args:
trainer: The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
The created optimizer, or `None` if none was found.
"""
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(self.cfg, trainer)
@@ -502,17 +469,17 @@ class PluginManager:
return None
def create_lr_scheduler(
self, trainer, optimizer, num_training_steps
self, trainer: Trainer, optimizer: Optimizer, num_training_steps: int
) -> LRScheduler | None:
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
"""Calls the `create_lr_scheduler` method of all registered plugins and returns
the first non-`None` scheduler.
Parameters:
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Args:
trainer: The trainer object for training.
optimizer: The optimizer for training.
Returns:
object: The created learning rate scheduler, or None if none was found.
The created learning rate scheduler, or `None` if not found.
"""
for plugin in self.plugins.values():
scheduler: LRScheduler | None = plugin.create_lr_scheduler(
@@ -525,16 +492,17 @@ class PluginManager:
return scheduler
return None
def add_callbacks_pre_trainer(self, cfg, model):
"""
Calls the add_callbacks_pre_trainer method of all registered plugins.
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
"""Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
A list of callback functions to be added to the `TrainingArgs`.
"""
callbacks = []
for plugin in self.plugins.values():
@@ -543,16 +511,17 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Calls the add_callbacks_post_trainer method of all registered plugins.
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
"""Calls the `add_callbacks_post_trainer` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Args:
cfg: The configuration for the plugins.
trainer: The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
A list of callback functions to be added to the `TrainingArgs`.
"""
callbacks = []
for plugin in self.plugins.values():
@@ -561,41 +530,30 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train(self, cfg, model):
"""
Calls the post_train method of all registered plugins.
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the post_train method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
def post_train_unload(self, cfg: DictDefault):
"""Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
Args:
cfg: The configuration for the plugins.
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
"""Base class for factories to create custom optimizers"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
) -> Optimizer | None:
pass

View File

@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
LOG = get_logger(__name__, use_environ=True)
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
# The patch checks model_type internally
cce_patch(cfg.model_config_type)

View File

@@ -15,12 +15,13 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CutCrossEntropyArgs(BaseModel):

View File

@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
_CONFIG_FOR_DOC,
COHERE_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
_CONFIG_FOR_DOC,
GEMMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -20,15 +20,11 @@ from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -170,10 +162,6 @@ def cce_forward(
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
_CONFIG_FOR_DOC,
LLAMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -16,22 +16,12 @@ from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import (
_CONFIG_FOR_DOC,
LLAMA4_INPUTS_DOCSTRING,
Llama4CausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -160,9 +150,6 @@ def cce_forward(
)
@replace_return_docstrings(
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None, # type: ignore

View File

@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
_CONFIG_FOR_DOC,
MISTRAL_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -15,23 +15,14 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
MLLAMA_INPUTS_DOCSTRING,
_prepare_cross_attention_mask,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -164,10 +155,6 @@ def cce_forward(
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
_CONFIG_FOR_DOC,
QWEN2MOE_INPUTS_DOCSTRING,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
_CONFIG_FOR_DOC,
QWEN2_VL_INPUTS_DOCSTRING,
Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
TransformersModelT,
apply_lce,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
_CONFIG_FOR_DOC,
QWEN3_MOE_INPUTS_DOCSTRING,
KwargsForCausalLM,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -2,15 +2,15 @@
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
LOG = get_logger(__name__)
class GrokfastCallbackHandler(TrainerCallback):

View File

@@ -19,16 +19,15 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
LOG = logging.getLogger("axolotl.integrations.liger")
LOG = get_logger(__name__, use_environ=True)
class LigerPlugin(BasePlugin):
@@ -85,10 +84,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
if is_main_process(use_environ=True):
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -124,9 +120,9 @@ class LigerPlugin(BasePlugin):
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
@@ -175,7 +171,17 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite
apply_liger_kernel_to_granite(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
else:
logging.warning(
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)

View File

@@ -15,12 +15,13 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class LigerArgs(BaseModel):

View File

@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
# @replace_return_docstrings(
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
# )
def lce_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.jamba.modeling_jamba import (
_CONFIG_FOR_DOC,
JAMBA_INPUTS_DOCSTRING,
HybridMambaAttentionDynamicCache,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -3,7 +3,6 @@ Sparse Finetuning plugin for Axolotl — enables handling of sparse neural netwo
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
@@ -16,11 +15,12 @@ from transformers.trainer_callback import TrainerCallback, TrainerControl, Train
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
LOG = get_logger(__name__)
class LLMCompressorCallbackHandler(TrainerCallback):

View File

@@ -17,14 +17,16 @@ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
import json
import logging
import requests
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
LOG = get_logger(__name__)
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
unfrozen_parameters = {}
@@ -83,17 +85,17 @@ class SpectrumPlugin(BasePlugin):
except FileNotFoundError:
pass
except Exception as exc: # pylint: disable=broad-exception-caught
logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data:
try:
snr_data = requests.get(snr_url, timeout=60).json()
except requests.exceptions.RequestException as exc:
logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
return
# also catch json parsing errors
except json.JSONDecodeError as exc:
logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
return
unfrozen_parameters = _generate_unfrozen_params_yaml(

View File

@@ -1,5 +1,4 @@
"""
Module for definition of GEGLU Triton kernels.
"""Module for definition of GEGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
@@ -12,8 +11,6 @@ import torch
import triton
import triton.language as tl
SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π)
@triton.jit
def _geglu_fwd_kernel(

Some files were not shown because too many files have changed in this diff Show More