Compare commits

44 Commits

Author SHA1 Message Date
Wing Lian
3b5a9d1d88 update create_optimizer for updated api 2026-02-19 23:49:32 -05:00
Wing Lian
eb59070040 fix labels 2026-02-19 23:44:46 -05:00
Wing Lian
9722aaf7d8 fix for tokenizers change 2026-02-19 21:52:44 -05:00
Wing Lian
c5d20bbd79 integration branch for transformers#44041 2026-02-19 18:34:13 -05:00
NanoCode012
7fbedbd300 fix(doc): add limitation for unfrozen_parameters (#3416) 2026-02-19 18:32:26 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
NanoCode012
4f1b5ad29f fix: clarify how to use lm_eval plugin (#3404) [skip ci] 2026-02-15 07:52:30 -05:00
NanoCode012
d6a2532dd7 feat(doc): clarify how to use scattermoe (#3408) [skip ci]
* feat(doc): clarify how to use scattermoe

* chore: fix wording
2026-02-15 07:51:28 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
NanoCode012
06ac407b92 feat: improve telemetry log (#3398)
* fix: redact trackio and data_files

* fix: add new orgs to whitelist

* feat: add run id to logs for users to easily share

* fix: update to add more metrics

* fix: add missed experiment tracker

* chore: formatting in main
2026-02-10 23:01:34 +07:00
NanoCode012
4e22cf0651 fix: remove telemetry warning (#3397) [skip ci] 2026-02-10 23:01:16 +07:00
VED
a4ee56c315 fix: set rollout in GRPO training_kwargs (#3392) 2026-02-10 18:06:15 +07:00
NanoCode012
c67cbcb0f5 fix: ignore add_special_tokens and use test mode for generation for mistral tokenizer (#3396) [skip ci]
* fix: ignore add_special_tokens and use test mode for generation

* fix: incorrectly setting kwarg
2026-02-10 18:03:26 +07:00
NanoCode012
a2da852576 fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
2026-02-10 17:58:40 +07:00
madScientist10
37e9da7a53 add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci] 2026-02-10 17:53:09 +07:00
NanoCode012
ed7105dba7 fix: GRPO config not accept max_prompt_length (#3390) [skip ci] 2026-02-10 17:52:09 +07:00
NanoCode012
b6d3653f74 feat: add step3p5 for cce (#3384) [skip ci]
* feat: add step3p5 for cce

* chore: reorder model
2026-02-10 17:51:43 +07:00
NanoCode012
fcc4cfdb63 feat: add sageattention (#2823) [skip ci]
* feat: add sageattention

* feat: call path on pre model load

* fix: patch to use register to correct var

* fix: add strict check import at start

* chore: fix comments

* chore: refactor

* feat: add capability check

* fix: missed underscore

* fix: let sageattention use FA backend in transformers

* feat: update sage attention for attention mask and position ids

* feat: allow sample packing but add warning without packing

* fix: loss hitting 0 with packing and attention mask note

* feat: downcast embeds if sage attention too

* feat: add config validation

* feat: add attention docs

* chore: docs
2026-02-10 17:49:21 +07:00
VED
97a4f28511 fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none

* patch calculating  eval loss for cp
2026-02-10 17:47:26 +07:00
VED
86a5803212 train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count

* guard for none n zero
2026-02-10 17:44:55 +07:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
VED
0343a72cc9 add glm support + patch (#3329) [skip ci]
* add glm support + patch

* lint

* lint

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/processing_strategies.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patch removed

* lint

* lint2

* docs + rename

* rmv moe

* docs

* removed processor

* sdpa T_T"

* ddp_find_unused_parameters: true

* muti gpu yaml tested both

* muti gpu yaml tested both

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* rmv text only section + v5 comments

* rename

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-02-10 17:43:53 +07:00
Wing Lian
236dad3bb7 set 0.15.0.dev0 version (#3380) 2026-01-30 21:28:01 -05:00
Wing Lian
be00978bc2 tag for v0.14.0 release (#3379)
2026-01-30 14:10:27 -05:00
Wing Lian
3738978394 Add support for batched_mm, grouped_mm and scattermoe for MoE models (#3377)
* kernels plugin for moe for v5

* add support for native batched_mm or grouped_mm
2026-01-29 14:25:47 -05:00
Wing Lian
6132a30cda handle warnings from v5 upgrade (#3376) 2026-01-28 06:45:01 -05:00
NanoCode012
3dd86d35b8 feat: add new cce support for glm series and exaone4 (#3373) [skip ci] 2026-01-28 06:44:44 -05:00
salman
dd9ebaeba1 EAFT (#3366) [skip ci]
* wip eaft

* fix eaft loss fn

* adding ref

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
2026-01-28 06:44:15 -05:00
Wing Lian
fc4e37920b transformers v5 upgrade (#3272)
* Prepare for transformers v5 upgrade

* fix hf cli

* update for hf hub changes

* fix tokenizer apply_chat_template args

* remap include_tokens_per_second

* fix tps

* handle migration for warmup

* use latest hf hub

* Fix scan -> ls

* fix import

* fix for renaming of mistral common tokenizer -> backend

* update for fixed tokenziation for llama

* Skip phi35 tests for now

* remove mistral patch fixed upstream in huggingface/transformers#41439

* use namespacing for patch

* don't rely on sdist for e2e tests for now

* run modal ci without waiting too

* Fix dep for ci

* fix imports

* Fix fp8 check

* fsdp2 fixes

* fix version handling

* update fsdp version tests for new v5 behavior

* Fail multigpu tests after 3 failures

* skip known v5 broken tests for now and cleanup

* bump deps

* unmark skipped test

* re-enable test_fsdp_qlora_prequant_packed test

* increase multigpu ci timeout

* skip broken gemma3 test

* reduce timout back to original 120min now that the hanging test is skipped

* fix for un-necessary collator for pretraining with bsz=1

* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)

* torch_dtype deprecated

* load model in float32 for consistency with tests

* revert some test fixtures back

* use hf cache ls instead of scan

* don't strip fsdp_version

more fdsp_Version fixes for v5
fix version in fsdp_config
fix aliasing
fix fsdp_version check
check fsdp_version is 2 in both places

* Transformers v5 rc2 (#3347)

* bump dep

* use latest fbgemm, grab model config as part of fixture, un-skip test

* import AutoConfig

* don't need more problematic autoconfig when specifying config.json manually

* add fixtures for argilla ultrafeedback datasets

* download phi4-reasoning

* fix arg

* update tests for phi fast tokenizer changes

* use explicit model types for gemma3

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>

* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText

* chore: remove duplicate

* fix: attempt fix gemma3 text mode

* chore: lint

* ga release of v5

* need property setter for name_or_path for mistral tokenizer

* vllm not compatible with transformers v5

* setter for chat_template w mistral too

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2026-01-27 17:08:24 -05:00
Wing Lian
a531e9d946 upgrade vllm to v0.14.0 (#3345) 2026-01-21 20:00:18 -05:00
Wing Lian
04328aeb97 cu129 targets for ci builds (#3369)
* cu129 targets for ci builds

* remove copy-paste is_latest
2026-01-21 17:24:44 -05:00
VED
d0d26d5064 feat: Add GDPO Support (#3353)
* gdpo support - test left

* lint

* fixxes for vllm serv

* test advantages

* docss

* lint

* lint =

* gdpo simple + lint

* lint nit

* example

* lint

* trl 0.27.0

* blocklist

* test assert rmv

* add validation check for GDPO + sum_then_normalize

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-21 17:22:45 -05:00
Wing Lian
8623dd8a72 strip only starting 'v' char; e.g don't strip from '.dev' (#3368) [skip ci] 2026-01-21 14:19:03 -05:00
Wing Lian
8cd75cff9f use cuda 12.9.1 and add python 3.12 to base images (#3367) 2026-01-21 13:34:14 -05:00
Wing Lian
8ab9d9ea88 Version dev (#3365) 2026-01-20 22:58:29 -05:00
Wing Lian
6e42def14b set version to v0.13.1 (#3363)
2026-01-20 08:58:32 -05:00
Wing Lian
c413480b35 upgrade transformers to 4.57.6 and peft to 0.17.1 and datasets to 4.5.0 (#3361) 2026-01-16 11:48:50 -05:00
Wing Lian
8f25124269 upgrade transformers to 4.57.5 (#3358)
* upgrade transformers to 4.57.5

* explicitly set versions for fbgemm-gpu

* handle index url for cuda version

* explicitly set cu version for fbgemm deps, skip for 130

* cu suffix not needed on version if using whl subpath
2026-01-16 11:17:43 -05:00
Wing Lian
790df757cb don't install xformers in for arm64 (#3359)
* install xformers in the base docker image

* install numba and numpy first

* set CUDA_HOME for xformers install

* Set cuda  home env

* don't install xformers by default on aarch64/arm64
2026-01-16 09:02:37 -05:00
Wing Lian
d282f32481 don't install deepspeed in arm64 images (#3357) 2026-01-14 12:03:55 -05:00
Wing Lian
6331e4a130 fix amd64 and set 2.9.1 as latest cloud image (#3356) 2026-01-14 11:56:36 -05:00
salman
1410e4474e update PR template (#3349) [skip ci] 2026-01-14 09:39:21 -05:00
Wing Lian
dc77b5bf42 fix arm64 builds (#3355)
* fix syntax  for secrets in gha yaml

* setup env for uv too

* arm64 for base  uv too

* don't build causal-conv1d or mamba for arm64 and use arm64 wheels

* fix dockerfile syntax

* fix shell syntax
2026-01-14 09:38:48 -05:00
NanoCode012
359b7ad85e fix: gemma3_text model loading vision config (#3354)
* fix: gemma3-text mode loading vision config

* fix: improve defaults to use lora kernels
2026-01-13 09:49:23 -05:00
157 changed files with 2613 additions and 618 deletions

View File

@@ -15,6 +15,11 @@
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate)
## Types of changes

View File

@@ -21,6 +21,8 @@ jobs:
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -32,6 +34,7 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -39,6 +42,7 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -46,6 +50,15 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -53,6 +66,15 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -79,7 +101,7 @@ jobs:
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -90,7 +112,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: linux/amd64,linux/arm64
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -105,6 +127,8 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -116,6 +140,7 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -123,6 +148,7 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -130,6 +156,15 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -137,6 +172,15 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -148,6 +192,7 @@ jobs:
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -158,6 +203,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,22 +20,32 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
platforms: "linux/amd64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -61,7 +71,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -88,22 +98,32 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
platforms: "linux/amd64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -128,7 +148,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -149,11 +169,11 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.9.1
axolotl_extras:
is_latest:
- cuda: 128
cuda_version: 12.8.1
is_latest: true
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:

View File

@@ -35,21 +35,26 @@ jobs:
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: fbgemm-gpu
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
nightly_build: "true"
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: fbgemm-gpu
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -71,8 +76,8 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run -m cicd.multigpu

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging==23.2
pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -48,9 +48,9 @@ jobs:
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in setup.py
- name: Update version in VERSION file
run: |
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
- name: Build a source dist
run: |
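The updated step writes the tag name to `VERSION` after removing only a leading `v` (`sed 's/^v//'`), which matches the intent of the earlier commit "strip only starting 'v' char; e.g don't strip from '.dev'". A minimal Python sketch of the same rule follows; the function name `tag_to_version` is hypothetical and is not part of the repository.

```python
def tag_to_version(tag: str) -> str:
    """Drop one leading 'v' from a release tag and leave the rest untouched."""
    return tag[1:] if tag.startswith("v") else tag


# A naive global replacement would also eat the 'v' inside '.dev0'.
naive = "v0.15.0.dev0".replace("v", "")

print(tag_to_version("v0.14.0"))       # tag with prefix
print(tag_to_version("v0.15.0.dev0"))  # 'v' inside '.dev0' is preserved
print(naive)                           # shows the mangled result to avoid
```

Anchoring the match at the start of the string is what keeps development versions such as `0.15.0.dev0` intact.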

View File

@@ -48,7 +48,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |

View File

@@ -54,8 +54,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -82,7 +87,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -110,10 +115,10 @@ jobs:
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache scan
run: hf cache ls
- name: Run tests
run: |
@@ -127,7 +132,7 @@ jobs:
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache scan
run: hf cache ls
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -144,8 +149,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -172,7 +182,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch
run: |
@@ -200,7 +210,7 @@ jobs:
axolotl --help
- name: Show HF cache
run: hf cache scan
run: hf cache ls
- name: Run tests
run: |
@@ -209,10 +219,10 @@ jobs:
pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache scan
run: hf cache ls
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
needs: [pre-commit]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -248,16 +258,16 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -359,9 +369,9 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
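The `pytest` and `pytest-sdist` matrices above add Python 3.12 but exclude it for PyTorch 2.8.0 and 2.9.0. The sketch below enumerates the resulting combinations. It is a simplified model of matrix expansion, assuming every `exclude` entry names both keys (as it does here), and `expand_matrix` is a hypothetical helper rather than anything GitHub Actions exposes.

```python
from itertools import product

python_versions = ["3.11", "3.12"]
pytorch_versions = ["2.8.0", "2.9.0", "2.9.1"]
exclude = [
    {"python_version": "3.12", "pytorch_version": "2.8.0"},
    {"python_version": "3.12", "pytorch_version": "2.9.0"},
]


def expand_matrix(pythons, pytorches, excluded):
    """Return the cross product of both axes minus the excluded pairs."""
    combos = []
    for py, pt in product(pythons, pytorches):
        combo = {"python_version": py, "pytorch_version": pt}
        if combo not in excluded:
            combos.append(combo)
    return combos


jobs = expand_matrix(python_versions, pytorch_versions, exclude)
for job in jobs:
    print(job["python_version"], job["pytorch_version"])
```

Under these assumptions the matrix yields four jobs: Python 3.11 against all three PyTorch versions, and Python 3.12 against 2.9.1 only.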

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
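The rename from `dataset_processes` to `dataset_num_proc` comes with a deprecation message, according to commit 530a0c0bf0. One way such a migration could look is sketched below. This is an illustration only: `migrate_dataset_num_proc` and the warning text are assumptions and do not reproduce axolotl's actual validation code.

```python
import warnings


def migrate_dataset_num_proc(cfg: dict) -> dict:
    """Map the deprecated `dataset_processes` key onto `dataset_num_proc`."""
    cfg = dict(cfg)  # work on a copy so the caller's config is not mutated
    if "dataset_processes" in cfg:
        warnings.warn(
            "`dataset_processes` is deprecated; use `dataset_num_proc` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        old_value = cfg.pop("dataset_processes")
        # If both keys are present, the new key takes precedence.
        cfg.setdefault("dataset_num_proc", old_value)
    return cfg


print(migrate_dataset_num_proc({"dataset_processes": 4}))
```

Letting the new key win when both are set is a design choice of this sketch; a stricter implementation might raise an error instead.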

View File

@@ -39,7 +39,6 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -107,7 +106,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# dataset_num_proc: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -224,9 +223,6 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -352,8 +348,6 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -412,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_num_proc: ${DATASET_NUM_PROC}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
@@ -512,7 +506,6 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}

View File

@@ -88,7 +88,7 @@ Features:
#### Using pip
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.15.0.dev0

View File

@@ -251,7 +251,6 @@ website:
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
@@ -266,6 +265,7 @@ website:
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
@@ -320,6 +320,7 @@ website:
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:

View File

@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -17,7 +17,8 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -27,8 +28,11 @@ df_args = {
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)
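The change above lets the workflow pick the Jinja template through the `E2E_DOCKERFILE` environment variable, with `Dockerfile.jinja` as the default. The sketch below isolates that lookup using only the standard library. `select_template` is a hypothetical helper, and unlike the `os.environ.get(..., default)` call in the diff it also falls back when the variable is set but empty, which is an added assumption.

```python
import os

DEFAULT_TEMPLATE = "Dockerfile.jinja"


def select_template(env=None) -> str:
    """Return the template name from E2E_DOCKERFILE, or the default."""
    env = os.environ if env is None else env
    # `or` covers both a missing variable and an empty string.
    return env.get("E2E_DOCKERFILE") or DEFAULT_TEMPLATE


print(select_template({}))
print(select_template({"E2E_DOCKERFILE": "Dockerfile-uv.jinja"}))
```

In the workflow itself the empty case is already handled by `${{ matrix.dockerfile || 'Dockerfile.jinja'}}`, so the extra guard only matters when the script is run by hand.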

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 --maxfail=4 \
pytest -v --durations=10 -n2 --maxfail=3 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -20,13 +21,17 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
python scripts/unsloth_install.py | sh && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -43,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip cache purge

View File

@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \

View File

@@ -2,6 +2,7 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
@@ -31,20 +32,35 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac

View File

@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

140
docs/attention.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get an `undefined symbol` error while training, ensure you installed PyTorch before Axolotl. Alternatively, try reinstalling or downgrading a version.
:::
#### Flash Attention 3
Requirements: Hopper GPUs only; CUDA 12.8 is recommended
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using the latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA is recommended at the moment. We observed the loss dropping to 0 during full fine-tuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested in adding this or improving the SageAttention implementation, please open an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using this with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
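The GPU requirements listed above for Flash Attention and xFormers can be turned into a quick check. The sketch below is only an illustration: `suggest_attention` is a hypothetical helper, not part of Axolotl, and it assumes the usual CUDA compute capabilities (Turing 7.5, Ampere 8.0/8.6, Ada 8.9, Hopper 9.0). Architectures newer than Hopper are not covered by the requirements above, so treat the result for them as unverified.

```python
def suggest_attention(major: int, minor: int = 0) -> str:
    """Map a CUDA compute capability to an Axolotl attention config key.

    Hypothetical helper for illustration only.
    """
    # Ampere (8.0/8.6), Ada (8.9) and Hopper (9.0) meet the Flash Attention 2
    # requirement; Turing (7.5) and older fall back to xFormers.
    if (major, minor) >= (8, 0):
        return "flash_attention"
    return "xformers_attention"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        # get_device_capability() returns a (major, minor) tuple
        print(suggest_attention(*torch.cuda.get_device_capability()))
    else:
        print("No CUDA device detected")
```

Whatever key it prints still has to be set to `true` in your config by hand.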
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4

View File

@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login
hf auth login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -89,6 +89,10 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,6 +19,7 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
@@ -183,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}

View File

@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl
@@ -720,6 +721,102 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. With a single reward function, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum # GDPO behavior
  # or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
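To make the difference concrete, here is a minimal numeric sketch of the two aggregation orders. It is a simplification written for illustration, not TRL's implementation: the helper names, the `EPS` constant, and the example group are assumptions, and any further batch-level normalization is omitted.

```python
from statistics import mean, pstdev

EPS = 1e-4  # small constant to avoid division by zero (assumed value)


def normalize(values):
    """Z-score a list of rewards within one group of completions."""
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / (sigma + EPS) for v in values]


def sum_then_normalize(rewards, weights):
    """GRPO-style: weighted sum per completion, then normalize the totals."""
    totals = [sum(w * r for w, r in zip(weights, row)) for row in rewards]
    return normalize(totals)


def normalize_then_sum(rewards, weights):
    """GDPO-style: normalize each reward column, then take the weighted sum."""
    columns = [normalize(list(col)) for col in zip(*rewards)]
    return [sum(w * z for w, z in zip(weights, row)) for row in zip(*columns)]


# One group of four completions, each scored as (format, correctness).
group = [(1, 0), (0, 1), (1, 0), (1, 1)]
weights = [1.0, 1.0]

grpo = sum_then_normalize(group, weights)
gdpo = normalize_then_sum(group, weights)

for row, a, b in zip(group, grpo, gdpo):
    print(f"{row}: sum_then_normalize={a:+.3f}  normalize_then_sum={b:+.3f}")
```

The first two completions both sum to 1, so `sum_then_normalize` assigns them the same advantage. `normalize_then_sum` assigns them different advantages because each reward is scaled by its own spread within the group before the two are combined.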
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(c) > 10 else 0.0 for c in completions]


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Placeholder scoring logic; replace with your own
        score = 1.0 if str(answer) in str(completion) else 0.0
        rewards.append(score)
    return rewards
```
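Before launching a run, it can help to call reward functions directly on a few hand-written samples. The sketch below is a hypothetical, self-contained variant of the functions above. It assumes completions arrive as plain strings (conversational datasets may pass message lists instead) and that the dataset transform exposes an `answers` column, which is then passed as a keyword argument. The regex-based scoring is an illustrative choice, not Axolotl's or TRL's.

```python
import re


def format_reward(completions, **kwargs) -> list[float]:
    """1.0 when the completion contains a '#### <number>' answer line."""
    return [1.0 if re.search(r"####\s*-?\d", c) else 0.0 for c in completions]


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    """1.0 when the last number in the completion equals the reference answer."""
    rewards = []
    for completion, answer in zip(completions, answers):
        numbers = re.findall(r"-?\d+(?:\.\d+)?", completion)
        rewards.append(1.0 if numbers and numbers[-1] == str(answer) else 0.0)
    return rewards


# Hand-written samples standing in for model output and dataset columns.
completions = ["2 + 2 = 4\n#### 4", "I think the answer is 5"]
answers = ["4", "4"]

format_scores = format_reward(completions, prompts=["What is 2 + 2?"] * 2)
correct_scores = correctness_reward(completions, answers=answers)
print(format_scores)   # [1.0, 0.0]
print(correct_scores)  # [1.0, 0.0]
```

Accepting `**kwargs` keeps each function tolerant of extra columns it does not use, such as `prompts` in the first call.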
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
]
},
{

View File

@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -0,0 +1,77 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/eaft-gemma-3-1b
use_eaft: true
eaft_alpha: 1.0
eaft_k: 20
sequence_len: 1024
sample_packing: false
adapter:
lora_model_dir:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
max_steps: 1000
evaluation_strategy: "no"
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
debug:
deepspeed:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_linear: true
sequence_len: 2048

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_linear: true
sequence_len: 2048

View File

@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -32,8 +33,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_dropout: 0
lora_target_linear: true
wandb_project:
wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:

View File

@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

44
examples/glm46v/README.md Normal file
View File

@@ -0,0 +1,44 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
GLM-4.6V-Flash (9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full fine-tune** by removing `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,53 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,50 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -19,7 +19,6 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true

View File

@@ -0,0 +1,68 @@
base_model: meta-llama/Llama-3.2-1B-Instruct
chat_template: llama3
rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 128
  num_generations: 2
  temperature: 0.7
  top_p: 0.95
  use_vllm: false
  multi_objective_aggregation: normalize_then_sum
  reward_funcs:
    - rwd.format_reward
    - rwd.correctness_reward
  reward_weights: [1.0, 2.0]
  log_completions: true
  num_completions_to_print: 3
  scale_rewards: true
datasets:
  - path: openai/gsm8k
    name: main
    split: train[:1000]
    type: rwd.gsm8k_transform
val_set_size: 0.0
output_dir: ./outputs/llama3-gdpo-out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
max_steps: 100
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
weight_decay: 0.01
warmup_steps: 10
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
flash_attention: true
logging_steps: 1
save_steps: 50
save_safetensors: true
special_tokens:
  pad_token: "<|end_of_text|>"
seed: 42

View File

@@ -12,7 +12,6 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora

View File

@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -47,6 +47,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -12,7 +12,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
build-backend = "setuptools.build_meta"
[project]
@@ -24,6 +24,9 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
@@ -57,3 +60,6 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -2,24 +2,24 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.0.0
triton>=3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.4
liger-kernel==0.7.0
# END section
packaging==23.2
huggingface_hub>=0.36.0
peft>=0.18.0
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==4.57.1
transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
accelerate==1.12.0
datasets==4.4.2
datasets==4.5.0
deepspeed>=0.18.3
trl==0.25.1
trl==0.28.0
hf_xet==1.2.0
kernels==0.11.5
trackio>=0.13.0
typing-extensions>=4.15.0
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.16.0
openenv-core==0.1.0
schedulefree==1.4.1
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.6
mistral-common==1.8.8

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
)

View File

@@ -1,6 +1,5 @@
"""setup.py for axolotl"""
import ast
import os
import platform
import re
@@ -26,6 +25,7 @@ def parse_requirements(extras_require_map):
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -62,44 +62,68 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:
extras_require_map["vllm"] = ["vllm==0.14.0"]
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
if install_xformers:
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
if install_xformers:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
if install_xformers:
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
if install_xformers:
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
if install_xformers:
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
@@ -110,15 +134,11 @@ def parse_requirements(extras_require_map):
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
"r",
encoding="utf-8",
) as fin:
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
version_ = fin.read().strip()
return version_

View File

@@ -1,7 +1,11 @@
"""Axolotl - Train and fine-tune large language models"""
import pkgutil
from importlib.metadata import PackageNotFoundError, version
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.13.0.dev"
try:
__version__ = version("axolotl")
except PackageNotFoundError:
__version__ = "unknown"

View File

@@ -5,6 +5,6 @@ import os
from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
configure_logging()

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:

View File

@@ -24,7 +24,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
@@ -42,7 +41,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(

View File

@@ -14,8 +14,6 @@ from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from huggingface_hub import split_torch_state_dict_into_shards
@@ -40,17 +38,15 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
save under `save_path` as `model.safetensors`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
@@ -76,11 +72,7 @@ def _distributed_checkpoint_to_merged_weights(
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
@@ -98,19 +90,12 @@ def _distributed_checkpoint_to_merged_weights(
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
if index is not None:
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -123,13 +108,11 @@ def _distributed_checkpoint_to_merged_weights(
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.
Note: this is a CPU-bound process.
@@ -138,8 +121,6 @@ def merge_fsdp_weights(
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
@@ -177,7 +158,7 @@ def merge_fsdp_weights(
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path, safe_serialization
checkpoint_dir_, output_path
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
@@ -210,7 +191,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()

View File

@@ -102,12 +102,10 @@ def do_quantize(
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
@@ -121,7 +119,7 @@ def do_quantize(
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
model.push_to_hub(hub_model_id)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)

View File

@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps = 0
warmup_steps: int | float = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
@@ -230,6 +230,10 @@ class TrainerBuilderBase(abc.ABC):
else:
warmup_ratio = 0.03
# transformers v5
if warmup_ratio > 0.0 and warmup_steps == 0:
warmup_steps = warmup_ratio
if warmup_steps == 1:
warmup_steps = 2
@@ -242,7 +246,6 @@ class TrainerBuilderBase(abc.ABC):
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
@@ -406,6 +409,9 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
@@ -530,9 +536,7 @@ class TrainerBuilderBase(abc.ABC):
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
@@ -545,6 +549,7 @@ class TrainerBuilderBase(abc.ABC):
arg_map = {
"dion_learning_rate": "dion_lr",
"include_num_input_tokens_seen": "include_tokens_per_second",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:

View File

@@ -246,7 +246,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
@@ -373,6 +374,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_eaft:
from functools import partial
from axolotl.monkeypatch.loss.eaft import eaft_loss
configured_eaft_loss = partial(
eaft_loss,
alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
)
trainer_kwargs["compute_loss_func"] = configured_eaft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
@@ -437,7 +450,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
self.cfg.micro_batch_size == 1 and is_eval is False
):
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -52,12 +51,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
@@ -134,19 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length")
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -155,10 +153,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
elif self.cfg.rl is RLType.GRPO:
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
training_args_kwargs.setdefault(
"multi_objective_aggregation", "normalize_then_sum"
)
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig

View File

@@ -25,7 +25,7 @@ from torch.utils.data import (
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -719,6 +719,13 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
@@ -738,43 +745,38 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -126,8 +126,10 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
return grpo_args_kwargs
@@ -149,6 +151,8 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs
@@ -159,7 +163,12 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self):
def create_optimizer(self, model=None):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
opt_model = self.model if model is None else model
if (
not self.optimizer

View File

@@ -1,12 +1,10 @@
"""Module for TRL RL trainers"""
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from trl import RewardTrainer
from trl.experimental.cpo import CPOTrainer
from trl.experimental.kto import KTOTrainer
from trl.experimental.orpo import ORPOTrainer
from trl.experimental.prm import PRMTrainer
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin

View File

@@ -8,7 +8,11 @@ from dataclasses import dataclass, field
from typing import Optional, Type
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from trl import RewardConfig
from trl.experimental.cpo import CPOConfig
from trl.experimental.kto import KTOConfig
from trl.experimental.orpo import ORPOConfig
from trl.experimental.prm import PRMConfig
from axolotl.integrations.config import merge_training_args

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"
```
## Usage
@@ -36,6 +36,7 @@ plugins:
- cohere
- cohere2
- deepseek_v3
- exaone4
- gemma
- gemma2
- gemma3
@@ -45,13 +46,16 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4_moe_lite
- glm46v
- glm4v
- glm4v_moe
- glm_image
- gpt_oss
- granite
- granitemoe
- granitemoeshared
- granitemoehybrid
- granitemoeshared
- hunyuan_v1_dense
- hunyuan_v1_moe
- internvl
@@ -76,16 +80,17 @@ plugins:
- phi3
- phi4_multimodal
- qwen2
- qwen2_vl
- qwen2_moe
- qwen2_vl
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_next
- qwen3_vl
- qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss
- smollm3
- step3p5
- voxtral
## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`'
)
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like(
self,
model_type: str,
model_type_to_patch: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str, remote_model_id: str | None
maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}"
) from e
if model_type not in PATCH_FNS:
if model_type_to_patch not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
"Setting up generic cce patch for model type: %s", model_type_to_patch
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -0,0 +1,44 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
Our custom integration adds support for **ScatterMoE**, which is faster and more efficient than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
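**Important:** `use_scattermoe` is incompatible with any `experts_implementation` other than `eager`. If it is unset or set to another value, the plugin resets it to `eager` (with a warning in the latter case).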
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE routes with softmax followed by top-k, so results may differ from the baseline for model architectures that route differently (e.g. GPT-OSS, GLM_MOE_DSA).
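To illustrate why the routing order matters, here is a minimal pure-Python sketch. It is not the ScatterMoE kernel or any model's actual router; the function names `softmax_then_topk` and `topk_then_softmax` and the example logits are hypothetical. It compares normalizing over all experts before selecting the top-k against selecting the top-k logits first and normalizing only over those.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def topk_indices(xs, k):
    """Indices of the k largest values, largest first."""
    return sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]


def softmax_then_topk(logits, k):
    """Normalize over all experts, then keep the k largest weights."""
    probs = softmax(logits)
    return {i: probs[i] for i in topk_indices(probs, k)}


def topk_then_softmax(logits, k):
    """Pick the k largest logits, then normalize over those only."""
    idx = topk_indices(logits, k)
    probs = softmax([logits[i] for i in idx])
    return dict(zip(idx, probs))


# Hypothetical router logits for four experts.
router_logits = [2.0, 1.0, 0.5, -1.0]
a = softmax_then_topk(router_logits, k=2)
b = topk_then_softmax(router_logits, k=2)
```

Both orderings select the same experts here, but the weights differ: in `a` they sum to less than 1 because probability mass went to the unselected experts, while in `b` they sum to 1. A model whose baseline router uses a different ordering or renormalization than the kernel would therefore see different expert mixing weights.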
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -0,0 +1,7 @@
from .args import KernelsArgs
from .plugin import KernelsPlugin
__all__ = [
"KernelsArgs",
"KernelsPlugin",
]

View File

@@ -0,0 +1,35 @@
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
@model_validator(mode="before")
@classmethod
def check_use_kernels(cls, data):
if data.get("use_kernels") is not True:
LOG.warning(
"`use_kernels` must be set to True to use this. Automatically setting it to True."
)
data["use_kernels"] = True
return data
@model_validator(mode="before")
@classmethod
def check_experts_implementation(cls, data):
experts_implementation = data.get("experts_implementation")
if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation != "eager":
LOG.warning(
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
)
data["experts_implementation"] = "eager"
return data

View File

@@ -0,0 +1,61 @@
from kernels import (
LayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
class KernelsPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
def _register_kernels(self):
register_kernel_mapping(
{
"HFScatterMoEParallelExperts": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="axolotl-ai-co/scattermoe",
layer_name="HFScatterMoEGatedMLP",
),
Mode.INFERENCE: LayerRepository(
repo_id="axolotl-ai-co/scattermoe",
layer_name="HFScatterMoEGatedMLP",
),
},
}
}
)
def _kernelize_model(self, model_type: str):
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
replace_kernel_forward_from_hub(
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls

View File

@@ -12,7 +12,6 @@ def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
@@ -22,7 +21,6 @@ def save_compressed_model(
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
@@ -34,7 +32,6 @@ def save_compressed_model(
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -6,6 +6,12 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
@@ -16,9 +22,50 @@ lm_eval_tasks:
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
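
The priority order above can be sketched as a small helper. This is an illustrative sketch, not the plugin's code: `resolve_eval_model` is a hypothetical name, and a plain `dict` stands in for the axolotl config object.

```python
from typing import Optional


def resolve_eval_model(cfg: dict) -> Optional[str]:
    """Return the first configured model source, in priority order."""
    # Keys are checked from highest to lowest priority; unset or empty
    # values are skipped.
    for key in ("lm_eval_model", "hub_model_id", "output_dir"):
        value = cfg.get(key)
        if value:
            return value
    return None


# An explicit override wins over everything else.
explicit = resolve_eval_model(
    {"lm_eval_model": "meta-llama/Llama-2-7b-hf", "output_dir": "./outputs"}
)
# With no override and no hub repo, the local output directory is used.
local = resolve_eval_model({"output_dir": "./outputs"})
```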
## Citation
```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
model=get_model_path(cfg),
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -13,6 +13,21 @@ import yaml
from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
@@ -108,7 +123,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
model=get_model_path(cfg),
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -26,7 +26,6 @@ from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
@@ -226,6 +225,7 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._configure_experts_implementation()
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
@@ -233,6 +233,10 @@ class ModelLoader:
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _configure_experts_implementation(self):
if self.cfg.experts_implementation is not None:
self.model.set_experts_implementation(self.cfg.experts_implementation)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
@@ -334,7 +338,12 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
(
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
and not self.is_qlora_and_fsdp_enabled
)
or (
@@ -434,7 +443,7 @@ class ModelLoader:
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
self.model_config.model_type, AutoModelForImageTextToText
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
@@ -476,6 +485,7 @@ class ModelLoader:
max_memory = None
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["dtype"] = self.cfg.torch_dtype
is_ds_zero3 = is_deepspeed_zero3_enabled()
@@ -607,6 +617,10 @@ class ModelLoader:
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager"
@@ -670,7 +684,7 @@ class ModelLoader:
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -788,6 +802,7 @@ class ModelLoader:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:

View File

@@ -10,6 +10,7 @@ from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
@@ -96,6 +97,7 @@ class PatchManager:
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -153,12 +155,9 @@ class PatchManager:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(
self.cfg.chunked_cross_entropy_num_chunks,
use_dft=self.cfg.use_dynamic_finetuning,
)
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
patch_chunked_ce_loss_fn()
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
@@ -204,6 +203,13 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
@@ -223,13 +229,6 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
@@ -502,6 +501,7 @@ class PatchManager:
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference
):
# TODO(MengqingCao): split these patches separately


@@ -31,7 +31,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer
_patch_mistralcommontokenizer()


@@ -5,6 +5,7 @@ from typing import Type
import addict
import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
@@ -153,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -174,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels:
# num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try:
model_config = AutoConfig.from_pretrained(
model_config = config_cls.from_pretrained(
model_config_name,
trust_remote_code=trust_remote_code,
**config_kwargs,
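The hunk above swaps `AutoConfig` for a class looked up by name when `cfg.cls_model_config` is set. A minimal sketch of that lookup pattern follows. It uses a stand-in namespace instead of the real `transformers` module so it runs without dependencies. The `resolve_config_cls` helper and the `ValueError` wrapping are assumptions added for illustration; the diff itself calls `getattr` directly.

```python
from types import SimpleNamespace


class AutoConfig:  # stand-in for transformers.AutoConfig
    pass


class LlamaConfig:  # stand-in for transformers.LlamaConfig
    pass


# Hypothetical stand-in for the `transformers` module namespace.
fake_transformers = SimpleNamespace(AutoConfig=AutoConfig, LlamaConfig=LlamaConfig)


def resolve_config_cls(module, cls_name=None, default=AutoConfig):
    """Return the class named `cls_name` from `module`, or `default` if unset."""
    if not cls_name:
        return default
    try:
        return getattr(module, cls_name)
    except AttributeError as err:
        # Not in the diff: surface a clearer error for a mistyped class name.
        raise ValueError(f"Unknown config class: {cls_name!r}") from err
```

With a plain `getattr`, a mistyped class name surfaces as an `AttributeError` on the module; wrapping it is one way to make the misconfiguration easier to spot.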


@@ -111,7 +111,6 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None,
):
if state_dict is None:
state_dict = self.state_dict()


@@ -0,0 +1,211 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False, # reduces loss 0 / nan grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")
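The sample-packing branch above relies on `position_ids` to locate sequence boundaries and passes cumulative sequence lengths to `sageattn_varlen`. The sketch below illustrates the idea in plain Python for a single packed row. It is not the `transformers` implementation of `prepare_fa2_from_position_ids`; the helper names are hypothetical, and it assumes a non-empty row whose first position id is 0 and in which every new sequence restarts at 0.

```python
from itertools import accumulate


def cu_seqlens_from_position_ids(position_ids):
    """Return (cu_seqlens, max_seqlen) for one packed row of position ids.

    Assumes a new sequence starts wherever the position id equals 0.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    boundaries = starts + [len(position_ids)]
    lengths = [end - start for start, end in zip(boundaries, boundaries[1:])]
    cu_seqlens = [0] + list(accumulate(lengths))
    return cu_seqlens, max(lengths)


def looks_packed(position_ids):
    """Mirror the diff's heuristic: any decrease in position ids implies packing."""
    return any(b < a for a, b in zip(position_ids, position_ids[1:]))
```

For a row such as `[0, 1, 2, 0, 1, 0, 1, 2, 3]`, the three sequences have lengths 3, 2, and 4, so the cumulative boundaries are `[0, 3, 5, 9]` and the maximum length is 4.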


@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
output = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
torch.autograd.backward(output, dY)
return (
None,
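The change above accepts both a bare tensor and a tuple from `ctx.forward_function`. A minimal sketch of that unwrapping, using plain Python values in place of tensors; `unwrap_single_output` is a hypothetical helper name, not part of the diff.

```python
def unwrap_single_output(output):
    """Return the lone element of a 1-tuple or 1-list, or the value itself."""
    if isinstance(output, (tuple, list)):
        # Single-element unpacking: raises ValueError unless len(output) == 1.
        (output,) = output
    return output
```

Note that single-element unpacking raises `ValueError` for tuples with more than one item, so a layer returning `(hidden_states, present_kv, ...)` would need `output[0]` instead.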


@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Axolotl could not import attention class for model_type: {model_type}. "
"Please raise an Issue and turn off lora kernels to continue training. "
f"Error: {str(e)}"
) from e


@@ -16,16 +16,10 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
"""
def __init__(
self,
num_output_chunks: int = 8,
ignore_index: int = -100,
use_dft: bool = False,
):
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
super().__init__()
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
self.use_dft = use_dft
def compute_cross_entropy(
self,
@@ -36,30 +30,10 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
"""
Upcast logits to fp32 and compute cross entropy loss.
"""
ce_loss = F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
return F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
)
if self.use_dft:
# Compute probabilities and gather the ones corresponding to labels
with torch.no_grad(): # Stop gradient
probs = torch.softmax(logits.float(), dim=-1)
# Create mask for valid tokens (not ignore_index)
valid_mask = labels != self.ignore_index
# Gather probabilities for the correct tokens
label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# Apply mask to only scale valid tokens
label_probs = label_probs * valid_mask
# Avoid multiplication by 0 for ignored tokens
label_probs = torch.where(
valid_mask, label_probs, torch.ones_like(label_probs)
)
# Scale the loss by the probability (DFT)
ce_loss = ce_loss * label_probs
return ce_loss.sum()
def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor:
@@ -97,20 +71,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
return total_loss / total_elements
def _build_chunked_ce_loss_fn(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor"
)
return loss_fn_ce
def get_causal_lm_loss(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
def chunked_fix_cross_entropy(
source,
@@ -154,14 +124,10 @@ def get_causal_lm_loss(
return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(
num_output_chunks, ignore_index, use_dft
)
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss


@@ -0,0 +1,51 @@
"""
eaft (entropy-aware focal training) loss implementation
weights examples by entropy approximation from top-k logits
Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
"""
import torch
import torch.nn.functional as F
def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):
"""
compute eaft loss with entropy weighting
args:
outputs: model outputs containing logits
labels: target labels for computing loss
num_items_in_batch: for sample packing support
alpha: exponent for entropy weighting (default 1.0)
k: number of top logits for entropy approximation (default 20)
"""
logits = outputs.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
vocab_size = shift_logits.size(-1)
shift_logits_view = shift_logits.view(-1, vocab_size)
shift_labels_view = shift_labels.view(-1)
mask = shift_labels_view != -100
with torch.no_grad():
top_k_logits, _ = torch.topk(
shift_logits_view[mask].float(), k=min(k, vocab_size), dim=-1
)
top_k_probs = F.softmax(top_k_logits, dim=-1)
entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)
weights = torch.pow(entropy, alpha)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(shift_logits_view[mask], shift_labels_view[mask])
weighted_loss = per_token_loss * weights
if num_items_in_batch is not None:
loss = weighted_loss.sum() / num_items_in_batch
else:
loss = weighted_loss.mean()
return loss
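To make the weighting in `eaft_loss` concrete, the sketch below computes the same top-k entropy approximation for a single token position using only the standard library. The function names are hypothetical, and plain lists stand in for tensors; it is meant as an illustration of the formula, not a replacement for the torch code above.

```python
import math


def topk_entropy(logits, k=20):
    """Entropy (natural log) of the softmax over the k largest logits."""
    top = sorted(logits, reverse=True)[: min(k, len(logits))]
    peak = top[0]
    exps = [math.exp(x - peak) for x in top]  # subtract the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    # Same 1e-10 offset inside the log as in the diff.
    return -sum(p * math.log(p + 1e-10) for p in probs)


def eaft_weight(logits, alpha=1.0, k=20):
    """Per-token weight: top-k entropy raised to the power alpha."""
    return topk_entropy(logits, k) ** alpha
```

A uniform distribution over four logits gives an entropy close to `ln 4`, while a sharply peaked distribution gives an entropy close to zero, so confidently predicted tokens contribute little to the weighted loss.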


@@ -1,5 +1,5 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
Monkeypatch to fix inefficient tensor conversion in MistralCommonBackend.apply_chat_template
"""
import importlib
@@ -12,11 +12,11 @@ LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonTokenizer
"""Apply patch to MistralCommonBackend.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonBackend
# Get original source
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
original_source = inspect.getsource(MistralCommonBackend.apply_chat_template)
original_source, _ = detab_code(original_source)
# Define the replacement
@@ -41,7 +41,7 @@ def apply_mistral_tokenizer_image_patch():
)
# Load necessary imports from the module
module_name = MistralCommonTokenizer.__module__
module_name = MistralCommonBackend.__module__
module = importlib.import_module(module_name)
# Detect what needs to be imported
@@ -79,7 +79,7 @@ def apply_mistral_tokenizer_image_patch():
exec(patched_source, globals()) # nosec B102
# Replace the method
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
MistralCommonBackend.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonBackend tensor conversion patch")
else:
LOG.warning("Could not find target code for MistralCommonTokenizer patching")
LOG.warning("Could not find target code for MistralCommonBackend patching")


@@ -155,7 +155,6 @@ class ReLoRACallback(TrainerCallback):
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
@@ -214,7 +213,7 @@ class ReLoRACallback(TrainerCallback):
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
model.model.save_pretrained(checkpoint_folder)
return control


@@ -52,9 +52,15 @@ def patch_prepare_context_parallel_inputs() -> None:
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
# Use a separate namespace to capture the exec'd function
namespace = {}
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
exec(patched_source, namespace)
# Explicitly get the function from the namespace
axolotl_prepare_context_parallel_inputs = namespace[
"axolotl_prepare_context_parallel_inputs"
]
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
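The hunk above stops exec'ing patched source into `globals()` and instead captures the result from a dedicated namespace dict. A small self-contained sketch of that pattern; `compile_function` and the `patched_sqrt` source are hypothetical examples, not code from the repository.

```python
def compile_function(source, name, imports=""):
    """Exec `source` in a fresh namespace and return the object called `name`."""
    namespace = {}
    if imports:
        exec(imports, namespace)  # names the patched source depends on
    exec(source, namespace)
    return namespace[name]  # KeyError if the source did not define `name`


PATCHED_SOURCE = '''
def patched_sqrt(x):
    return sqrt(x) + 1
'''

patched_sqrt = compile_function(
    PATCHED_SOURCE, "patched_sqrt", imports="from math import sqrt"
)
```

The function keeps a reference to the namespace as its globals, so `sqrt` resolves at call time without being added to the caller's module globals.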


@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
ORIGINAL_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
def check_evaluation_loop_is_patchable() -> bool:


@@ -14,7 +14,6 @@ from transformers.models.voxtral import VoxtralProcessor
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -430,7 +429,7 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
def __init__(
self,
processor: Mistral3Processor,
processor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
@@ -486,6 +485,58 @@ class InternVLProcessingStrategy(ProcessingStrategy):
return labels
class Glm4vProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.tokenizer = getattr(processor, "tokenizer", processor)
self.image_token = "<|image|>" # nosec
self.begin_image_token = "<|begin_of_image|>" # nosec
self.end_image_token = "<|end_of_image|>" # nosec
self.video_token = "<|video|>" # nosec
self.begin_video_token = "<|begin_of_video|>" # nosec
self.end_video_token = "<|end_of_video|>" # nosec
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_image_token
)
self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_image_token
)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_video_token
)
self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_video_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == self.begin_image_token_id] = -100
labels[labels == self.end_image_token_id] = -100
labels[labels == self.video_token_id] = -100
labels[labels == self.begin_video_token_id] = -100
labels[labels == self.end_video_token_id] = -100
return labels
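`Glm4vProcessingStrategy.process_labels` sets the label to -100 for padding and for each image/video marker token so those positions are ignored by the loss. The same masking, sketched on a plain list; the token ids in the usage note are made up for illustration and do not correspond to a real tokenizer.

```python
IGNORE_INDEX = -100


def mask_special_tokens(input_ids, masked_ids):
    """Copy `input_ids`, replacing any id in `masked_ids` with IGNORE_INDEX."""
    masked = set(masked_ids)
    return [IGNORE_INDEX if tok in masked else tok for tok in input_ids]
```

For example, with hypothetical ids where 0 is padding and 7 and 8 are image markers, `mask_special_tokens([1, 7, 8, 2, 0], {0, 7, 8})` leaves only the text tokens 1 and 2 as training targets.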
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -493,6 +544,8 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,
@@ -500,10 +553,10 @@ def get_processing_strategy(
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type in [None, "tokenizer_default"]:
tokenizer = getattr(processor, "tokenizer", processor)
if hasattr(tokenizer, "chat_template"):
processing_kwargs["chat_template"] = tokenizer.chat_template
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
@@ -532,6 +585,15 @@ def get_processing_strategy(
return Mistral3ProcessingStrategy(
**processing_kwargs,
)
try:
from transformers.models.glm46v.processing_glm46v import Glm46VProcessor
if isinstance(processor, Glm46VProcessor):
return Glm4vProcessingStrategy(
**processing_kwargs,
)
except ImportError:
pass
if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy(


@@ -150,6 +150,8 @@ class ChatTemplatePrompter(Prompter):
return self.tokenizer.apply_chat_template(
conversation,
tokenize=True,
return_dict=False,
**chat_template_kwargs,
)


@@ -153,13 +153,27 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, and grad_norm from log history."""
"""Extract last loss, learning_rate, grad_norm, and token metrics from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
return {
"loss": 0,
"ppl": 0,
"learning_rate": 0,
"grad_norm": 0,
"tokens/total": 0,
"tokens/trainable": 0,
"tokens/train_per_sec_per_gpu": 0,
}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"ppl": last_log.get("ppl", 0),
"learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0),
"tokens/total": last_log.get("tokens/total", 0),
"tokens/trainable": last_log.get("tokens/trainable", 0),
"tokens/train_per_sec_per_gpu": last_log.get(
"tokens/train_per_sec_per_gpu", 0
),
}
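The updated `_extract_last_metrics` repeats the same seven keys in both return branches. One possible way to express it with a single key list is sketched below; this is an illustrative alternative, assuming `log_history` is a list of dicts, and is not what the commit does.

```python
METRIC_KEYS = (
    "loss",
    "ppl",
    "learning_rate",
    "grad_norm",
    "tokens/total",
    "tokens/trainable",
    "tokens/train_per_sec_per_gpu",
)


def extract_last_metrics(log_history):
    """Return the tracked metrics from the last log entry, defaulting to 0."""
    last_log = log_history[-1] if log_history else {}
    return {key: last_log.get(key, 0) for key in METRIC_KEYS}
```

Both the empty-history case and missing keys fall through to the same default, which keeps the two branches from drifting apart when a metric is added.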


@@ -155,6 +155,10 @@ def send_errors(func: Callable) -> Callable:
},
)
LOG.error(
f"Error captured in telemetry. Run ID: {telemetry_manager.run_id}"
)
raise
return wrapper


@@ -5,7 +5,6 @@ import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any
@@ -20,21 +19,6 @@ LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes
@@ -46,8 +30,8 @@ FIELDS_TO_REDACT = {
"resume_from_checkpoint",
"hub_model_id",
}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_", "trackio_", "swanlab_"}
PATH_INDICATORS = {"path", "dir", "data_files"}
# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
@@ -183,11 +167,6 @@ class TelemetryManager:
"false",
"true",
):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True
# Only rank 0 will send telemetry


@@ -31,3 +31,10 @@ organizations:
- "mistral-community"
- "llava-hf"
- "ByteDance-Seed"
- "ACE-Step"
- "openbmb"
- "MiniMaxAI"
- "stepfun-ai"
- "internlm"
- "katanemo"
- "XiaomiMiMo"


@@ -135,16 +135,13 @@ def setup_reference_model(
return model_ref
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
def setup_signal_handler(cfg: DictDefault, model: PreTrainedModel):
"""
Set up signal handler for graceful termination.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
# ray workers don't have access to this signal
if cfg.local_rank == 0 and not cfg.use_ray:
@@ -152,9 +149,7 @@ def setup_signal_handler(
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
_model.save_pretrained(cfg.output_dir)
cleanup_distributed()
sys.exit(0)
@@ -219,7 +214,6 @@ def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
@@ -228,7 +222,6 @@ def save_trained_model(
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
@@ -283,7 +276,6 @@ def save_trained_model(
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=merged_path,
safe_serialization=True,
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
@@ -330,11 +322,9 @@ def save_trained_model(
pass
elif cfg.local_rank == 0:
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
trainer.model.save_pretrained(cfg.output_dir)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
model.save_pretrained(cfg.output_dir)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
@@ -344,7 +334,6 @@ def save_trained_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
@@ -449,7 +438,6 @@ def handle_untrained_tokens_fix(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
@@ -459,7 +447,6 @@ def handle_untrained_tokens_fix(
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
@@ -483,9 +470,7 @@ def handle_untrained_tokens_fix(
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
model.save_pretrained(str(Path(cfg.output_dir)))
def setup_model_and_trainer(
@@ -582,15 +567,12 @@ def train(
) = setup_model_and_trainer(cfg, dataset_meta)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
# Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
setup_signal_handler(cfg, model, safe_serialization)
setup_signal_handler(cfg, model)
setup_model_card(cfg)
# Execute the training
@@ -602,7 +584,7 @@ def train(
torch.cuda.empty_cache()
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
save_trained_model(cfg, trainer, model)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)

Some files were not shown because too many files have changed in this diff