Compare commits

...

58 Commits

Author SHA1 Message Date
mhenrhcsen
076b6e1e24 Revamp README.md with a fresh layout and enhanced content, including a new introduction, improved visual elements, and detailed sections on features, quick start guide, and comprehensive documentation. This update aims to create a more engaging and informative experience for users, highlighting Axolotl's capabilities in LLM fine-tuning. 2025-06-18 14:18:46 +02:00
mhenrhcsen
b9db0cad1d Enhance README.md with updated layout and content, including a new introduction, improved visual elements, and detailed sections on latest updates, features, quick start guide, and documentation. This update aims to provide a more engaging and informative experience for users. 2025-06-18 14:06:14 +02:00
Wing Lian
da8f6c32b9 update favicon (#2801)
* update favicon

* correct size favicon
2025-06-17 18:09:24 -04:00
Wing Lian
88c0e8d048 release tag (#2799)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-06-17 12:13:27 -04:00
NanoCode012
d8e8cd8558 feat: remove evalfirst callback with built-in trainer arg (#2797) 2025-06-17 12:09:33 -04:00
Wing Lian
ccc94da8ad KD fix w/ online distillation (#2700) [skip ci]
* kd fixes

* fix collator setup

* fix input args

* better handling to drop string fields for kd with raw dataset

* kd trainer has kd temp as part of the init

* drop top_k before softmax

* simplfy and remove zscore

* WIP chunked KD loss with autograd wrapper

* more fixes and liger-type chunked loss

* collator cls for plugins

* remove debugging

* additional plugin collator kwargs, don't scale up kd loss by t^2

* don't need temp arg to distill method

* online kd wip

* add close to comment block

* suport sampling params/max new tokens

* handle when no custom collator is used in plugins

* logsumexp trick:

* fix check

* shift off the first empty token

* fix length of padding

* use max not min

* temp scale kd loss at end

* support for dynamic plugin training args mixins and symmetric kl

* chore: lint

* fix trainer callback base class

* Fix decay

* accept compressed responses for smaller wire payload

* post-rebase lint

* more KD updates

* increase hyperparams_count for gradients for added normalize_topk

* fix to remove attention_mask

* rename vars for consistency

* fix rebase issues

* default to dropping last batch in multipack batch sampler

* improve handling of train len

* init collator_cls_and_kwargs

* explicit drop_last=False when checking for multipack completeness

* use separate v2 loader for kd

* fix kd tests to use subprocess so it picks up kd training args

* default value for kd_beta arg

* use updated dataset for ci

* longer timeout for e2e
2025-06-17 12:09:13 -04:00
Matt Cummins
ba62aa65ee fixed the lora_target_modules syntax (#2793) 2025-06-15 16:47:02 -04:00
NanoCode012
21388cf615 Fix: lora kernel pre-patch applied despite post-patch not applied (#2772)
* fix: do not pre-patch self attention if lora dropout non-zero

* fix: add test to check patch not applied

* fix: test

* fix: test config check

* fix where we check so that tests don't break

* fix: test

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-14 11:54:06 -07:00
NanoCode012
80d5b066ec Fix: adding magistral fsdp config, fixing not eval with test_datasets, handle mllama attention (#2789) [skip ci]
* feat: add fsdp config for magistral

* fix: add mllama self attention handling for lora kernels

* fix: no eval if val_set_size 0 despite having test_datasets

* fix: add note for cce for vlm in newer model
2025-06-14 11:53:43 -07:00
NanoCode012
a3c82e8cbb fix: grpo doc link (#2788) [skip ci] 2025-06-13 12:03:47 -07:00
Wing Lian
b2274d430b support for QAT w RL (DPO) (#2776) 2025-06-13 10:00:35 -04:00
NanoCode012
eac4a61f55 Feat: Add Magistral and mistral-common tokenizer support (#2780) 2025-06-12 19:18:33 -04:00
Wing Lian
ace9287c96 update loss value for flakey e2e test (#2786) [skip ci]
* update loss value for flakey e2e test

* use pytest skip

* parametrize combinations
2025-06-12 18:06:14 -04:00
JZacaroli
f5fbc82f2b Fix logging import in evaluate.py (#2782) (#2783)
* Fix logging import in evaluate.py (#2782)

* chore: lint

---------

Co-authored-by: Joe Zacaroli <jaz@cyberscience.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-12 13:23:31 -04:00
NanoCode012
706c677cad feat(doc): update readme to include changelog and remove matrix (#2775) [skip ci]
* feat(doc): update readme to include changelog and remove matrix

* chore: improve wording

* chore: wording

* Update README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* chore: address comment remove muon

* chore: address comments

* fix: address final comments

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-06-12 13:23:18 -04:00
Wing Lian
468580d18e limit multipack sampler processes (#2771) [skip ci]
* limit to 16 packing processes

* make num_processes properly reflect configured dataset_processes
2025-06-12 13:22:58 -04:00
salman
3634d8ff9d QAT docfix (#2778) [skip ci]
* nits

* Update docs/qat.qmd

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-06-12 13:22:40 -04:00
Wing Lian
bcc108efc1 build 2.7.1 images too (#2784) [skip ci] 2025-06-12 13:22:20 -04:00
Wing Lian
581dd324cc build base images for torch 2.7.1 (#2764)
* build base images for torch 2.7.1

* fix: update base docker to use torch 2.7.1

* fix: update doc for main base to use 2.7.1

* make sure to install fa2 in base uv too

* use no build isolation for uv+flashattn

* install psutil also for fa2

* longer timeout for flash attn build

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-06-11 17:11:06 -04:00
Dan Saunders
00cda8cc70 Data loader refactor (#2707)
* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldnt resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplication

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
2025-06-10 19:53:07 -04:00
Dan Saunders
52a0452acb magistral small placeholder (#2777) 2025-06-10 13:03:41 -04:00
NanoCode012
83632f71d8 Feat: add tool calling support via tools column (#2774)
* feat: add tool_calling field support

* fix: add tests
2025-06-09 21:42:05 -07:00
Qingyang Wu
92afa4fa27 Fix the bug of position ids padding (#2739) [skip ci]
* Update batching.py: fix the bug of position ids padding

if position ids is padded with a long sequence of zeros, it will cause flash attention to crash

* use alternate calculation for padding position_ids with a range

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-09 21:26:36 -07:00
Wing Lian
dd660c2ed0 handle when unable to save optimizer state when using ao optimizer with FSDP (#2773) [skip ci]
* handle when unable to save optimizer state when using ao optimizer with FSDP1

* improve messaging

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-06-09 21:26:14 -07:00
Wing Lian
09c685fd2c fix worker_init_fn signature handling (#2769) 2025-06-08 23:14:10 -07:00
Wing Lian
7909bfb076 add manual seed for flaky test_geglu_backward test (#2763) [skip ci] 2025-06-05 09:23:17 -07:00
Wing Lian
cb03c765a1 add uv tooling for e2e gpu tests (#2750)
* add uv tooling for e2e gpu tests

* fixes from PR feedback

* simplify check

* fix env var

* make sure to use uv for other install

* use raw_dockerfile_image

* Fix import

* fix args to experimental dockerfile image call

* use updated modal versions
2025-06-05 07:25:06 -07:00
Timofey Klyubin
4440b4a1ce remove unused field for chat_template.default for DPO training (#2755) [skip ci]
* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

* lint and fix tests for new return value

* remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

lint and fix tests for new return value

fix for updated expected fields for dpo

remove unused field for chat_template.default

"messages" field present in final dataset causes issues with DPO
training otherwise

fix test still expecting "messages" field

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-05 07:22:58 -07:00
NanoCode012
e8e45b3441 fix: remove hqq (#2759) [skip ci] 2025-06-05 07:22:23 -07:00
Wing Lian
c67910fa6f bump hf deps (#2735) [skip ci]
* bump hf deps

* upgrade liger-kernel too

* install cce from fork for transformers fix

* fix reference to vocab size in gemma3 patch

* use padding_idx instead of pad_token_id

* remove fixed gemma3 patch

* use updated cce fork

* fix local mllama cce patches w docstring

* add test for multipack with trainer setup and fix trainer for trainer refactor upstream

* bump modal version

* guard for iterable datasetS

* mllama model arch layout changed in latest transformers

* fix batch sampler with drop_last

* fix: address upstream vlm changes for lora

* fix: update references to old lora target path

* fix: remove mllama fa2 patch due to upstream fix

* fix: lora kernel patch path for multimodal models

* fix: removed mllama from quarto

* run test for came optim on 2.6.0+

* fix fsdp2 patch and remove deprecated patch

* make sure to set sequence_parallel_degree for grpo

* Add SP test for GRPO

* add sp to grpo config for trainer

* use reward_funcs as kwarg to grpo trainer

* fix the comprehension for reward funcs

* reward funcs already passed in as args

* init sp_group right before training

* fix check for adding models to SP context

* make sure to pass args to super

* upgrade deepspeed

* use updated trl and add reasoning flags for vllm

* patch the worker

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-06-05 07:20:33 -07:00
NanoCode012
787880215b fix(deepspeed): deepspeed config not being set for z3 (#2754)
* fix(deepspeed): deepspeed config not being set for z3

* fix: comments
2025-06-03 14:27:09 -07:00
NanoCode012
4b1a29c694 feat(modal): update docker tag to use torch2.6 from torch2.5 (#2749) [skip ci] 2025-06-03 14:26:07 -07:00
NanoCode012
d7fa60662e feat: add chat_template kwargs (#2694) [skip ci] 2025-06-03 14:25:26 -07:00
Dan Saunders
1d91d905c9 remove deprecated wandb env var (#2751)
* remove deprecated wandb env var

* remove os.environ wandb setting; unused loggers

* remove os.environ wandb setting; unused loggers
2025-06-03 14:04:15 -07:00
mhenrhcsen
2bf61d8e25 fix abbriviatation spelling error 2025-06-03 21:30:40 +02:00
mhenrhcsen
68788e419e feat: add Group Relative Policy Optimization (GPRO) to RLHF documentation 2025-06-03 21:30:40 +02:00
github-actions[bot]
94219f6ee8 chore: update pre-commit hooks (#2745)
* chore: update pre-commit hooks

* trigger linter when pre commit hooks are updated

* fix type checks from upgraded pre-commit

---------

Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-06-02 15:54:29 -07:00
Wing Lian
ecc719f5c7 add support for base image with uv (#2691) 2025-06-02 12:48:55 -07:00
NanoCode012
d5d0dc5938 fix: suppress non-axolotl logs unless it's warning or higher (#2724)
* fix: increase log level for root loggers and axolotl's

* fix: BasePlugin using wrong logger

* fix: update logger to take name from module

* feat: change logger class to AxolotlLogger to filter non-axolotl infos or below

* fix: change behavior to not disable existing loggers

* fix: update logging to respect correct env

* chore: fix comment

* fix: suppress accelerate log to LOG_LEVEL if not set

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-05-31 12:13:43 +07:00
NanoCode012
5e86c35322 fix(log): remove duplicate merge_lora param (#2742) [skip ci] 2025-05-31 12:13:31 +07:00
NanoCode012
6778856804 Fix: RL base feature parity (#2133)
* feat: add num_proc and load from cache for rl mapping

* fix: refactor sft and rl trainer to set same base args

* feat: add report_to to set run name

* fix: consolidate handling of fp16, bf16, tf32 kwarg

* chore: consolidate eval_strat, loraplus, lr sched, max_length

* fix: deprecate old types

* fix: adding missing Any

* fix: max_steps incorrectly set

* fix: remove unnecessary datacollator kwarg insert and pop

* fix: update default max_steps

* fix: add missing weight_decay handling

* fix: ignore max_length for grpo

* feat: update CI on trainer_builder

* fix: comments

* improve handling of warmup/logging steps

* use transformers default for logging steps, not None

* fix: remove redundant override

* fix: lint

* feat: allow custom optim for rl methods

* fix: duplicate optim setting

* fix(test): set sequence_parallel_degree default in base cfg

* feat: add handling for seed and SP/ring-attn config

* chore: add back return typing from rebase

* fix(test): use RLType directly to skip needing to validate

* feat: split training builder into sub modules

* fix: remove deprecated clause

* chore: add missing config to doc

* fix: update quarto autodoc

* fix: import path for trainer builder and submodules

* fix: remove redundant configs from rebase mistake

* chore: simplify dynamo check

* fix: optimizer_cls_and_kwargs to be passed into trainer_kwargs

* fix: add missing rex from rebase

* fix: move pop optimizer_cls_and_kwargs

* fix: pop optimizer cls in rl too

* fix: leftover bug from rebase

* fix: update handling of trainer_cls in RL

* fix: address pr feedback

* feat: call hook_pre_create_trainer for rl

* chore: lint

* fix: return notimplemented for ppo

* feat: moved torch compile to base and refactor collator setting

* chore: remove unused importlib.util import

* fix: optimizer cls not being popped

* feat: move epoch setting to base

* fix: catch unhandled custom optimizer

* fix: remove duplicate lora plus setting

* chore: refactor if condition

* chore: refactor set_base_training_args into smaller modules

* fix: address TrainerBuilderBase class variables to instance var

* fix: add handling for beta3 and episilon2

* fix: change to pass dict via arg instead of updating dict

* chore: simplify if condition

* fix: force access to lr & weight decay in case not provided to early error

* fix: remove log sweep

* chore: refactor if condition

* fix: address renamed cfg

* fix: improve handling of cosine hyp

* fix: remove unused params

* chore: refactor

* chore: clarify doc safetensors

* fix: update import path to be unified following comments

* fix: duplicate kwargs passed

* feat: return separate trainer_kwargs

* chore: refactor

* chore: refactor based on comments

* chore: refactor based on comments

* fix: move gpustats callback to base

* chore: create trainer_cls_args first based on comments

* fix: ipo label smoothing passed incorrectly

* feat: add optimizer parity for RL methods with test

* feat: add parity for optimizer in RM/PRM and add test

* fix: remove redundant function override for orpo/cpo batch metrics

* fix: improve handling of dpo_label_smoothing and merge issue

* fix: test fixture returning wrong field

* fix: address avoid direct modify fixture

* chore: minor refactor

* Revert "chore: refactor"

This reverts commit 99c8859eb0.

* feat: rename trainer_builder to builders

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-30 11:21:47 +07:00
Wing Lian
ec4ebfd997 Add a few items to faq (#2734)
* Add a few items to faq

* formatting

* chore: lint
2025-05-28 16:20:19 -04:00
Dan Saunders
bde8b5b6bd fix dist state init before deepspeed setup (#2737) 2025-05-28 14:59:57 -04:00
Dan Saunders
2962a398b7 Lora kernels fix (#2732)
* fix lora kernel patching and improve test

* simplification
2025-05-28 10:03:43 -04:00
salman
65c5481120 Rank 0-only logging (#2608)
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-28 14:57:30 +01:00
salman
5fca214108 QAT (#2590)
QAT and quantization w/torchao
2025-05-28 12:35:47 +01:00
NanoCode012
20fda75917 feat(doc): add google analytics to docs (#2708) 2025-05-28 15:51:21 +07:00
NanoCode012
6b6370f4e3 feat(doc): add info on how to use dapo / dr grpo and misc doc fixes (#2673) [skip ci]
* feat(doc): add info on how to use dapo / dr grpo

* chore: add missing config to docs

* fix: missing comment

* fix: add missing scheduler from schema

* chore: refactor lr scheduler docs

* fix: remove log_sweep
2025-05-28 15:51:04 +07:00
mashdragon
add2025253 Fix Mistral chat template (mistral_v7_tekken) (#2710) [skip ci]
Per 4b8dd8aae7 (d2h-482763)
2025-05-28 15:50:47 +07:00
artem
a703560a10 add two checks to handle legacy format interleaved multimodal ds (#2721) [skip ci]
* add two checks to handle legacy format interleaved ds

* fix: add warning about multiple image using legacy format

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-28 15:49:43 +07:00
NOHHYEOB, BAE
4a80d309e8 Add chat templates for command-a and aya-23-8B models (#2731) [skip ci]
* Add chat templates for command-a and aya model

* Fix: isolate for-loop update and remove unintended changes
2025-05-28 15:49:16 +07:00
NanoCode012
e33f225434 feat(doc): note lora kernel incompat with RLHF (#2706) [skip ci]
* feat(doc): note lora kernel incompat with RLHF

* fix: add validation following comments

* chore: fix typo following suggestion
2025-05-28 15:48:40 +07:00
NanoCode012
3e6948be97 Fix(doc): clarify data loading for local datasets and splitting samples (#2726) [skip ci]
* fix(doc): remove incorrect json dataset loading method

* fix(doc): clarify splitting only happens in completion mode

* fix: update local file loading on config doc

* fix: typo
2025-05-28 15:48:22 +07:00
github-actions[bot]
4a8af60d34 chore: update pre-commit hooks (#2729)
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-05-27 11:45:31 -04:00
Dan Saunders
a0941a9271 no need to generate diff file (#2728) 2025-05-27 11:44:06 -04:00
Dan Saunders
5eb01f3df1 Fix quarto (#2717)
* missing modules

* fix quarto complaints
2025-05-23 21:16:51 -04:00
xzuyn
d27c35ac44 Liger GraniteMoE (#2715) 2025-05-23 18:40:43 -04:00
Dan Saunders
a535b68043 update quarto for model loading refactor (#2716)
* update quarto for model loading refactor

* fix desc
2025-05-23 16:28:31 -04:00
263 changed files with 11016 additions and 5203 deletions

View File

@@ -16,8 +16,9 @@ on:
jobs:
build-base:
if: github.repository_owner == 'axolotl-ai-cloud'
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
matrix:
@@ -28,42 +29,50 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: next
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: next
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-next"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -85,7 +94,60 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
file: ./docker/${{ matrix.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDNN_VERSION=${{ matrix.cudnn_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
build-base-uv:
if: github.repository_owner == 'axolotl-ai-cloud'
timeout-minutes: 480
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -9,6 +9,7 @@ on:
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"
- ".pre-commit-config.yaml"
workflow_dispatch:
jobs:

View File

@@ -29,12 +29,12 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -97,12 +97,12 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -43,7 +43,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
@@ -59,7 +59,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -25,7 +25,6 @@ jobs:
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
@@ -39,11 +38,3 @@ jobs:
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -44,98 +44,6 @@ jobs:
env:
SKIP: no-commit-to-branch
# preload-cache:
# name: Preload HF cache
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python_version: ["3.11"]
# pytorch_version: ["2.6.0"]
# timeout-minutes: 20
#
# env:
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
#
# steps:
# - name: Check out repository code
# uses: actions/checkout@v4
#
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
#
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p /home/runner/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
#
# - name: Setup Python
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python_version }}
# cache: 'pip' # caching pip dependencies
#
# - name: upgrade pip
# run: |
# pip3 install --upgrade pip
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
#
# - name: Install PyTorch
# run: |
# pip3 install torch==${{ matrix.pytorch_version }}
#
# - name: Install dependencies
# run: |
# pip3 show torch
# pip3 install --no-build-isolation -U -e .
# python scripts/unsloth_install.py | sh
# python scripts/cutcrossentropy_install.py | sh
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
#
# - name: Make sure PyTorch version wasn't clobbered
# run: |
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
#
# - name: Ensure axolotl CLI was installed
# run: |
# axolotl --help
#
# - name: Pre-Download dataset fixture
# run: |
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
#
# - name: Run tests
# run: |
# pytest -v tests/conftest.py
#
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v5
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# files: ./coverage.xml
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
# fail_ci_if_error: false
#
# - name: cleanup pip cache
# run: |
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
#
# - name: Save HF cache
# id: hf-cache
# uses: actions/cache/save@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
@@ -144,22 +52,13 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
@@ -222,27 +121,17 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
@@ -299,7 +188,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist]
strategy:
@@ -312,6 +201,13 @@ jobs:
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -322,7 +218,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -333,6 +229,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -341,7 +238,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
@@ -365,13 +262,13 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:
@@ -384,7 +281,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -395,6 +292,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -424,7 +322,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -19,15 +19,15 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.1.2
rev: 7.2.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.6
rev: v3.3.7
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
rev: v1.16.0
hooks:
- id: mypy
additional_dependencies:

View File

@@ -242,16 +242,12 @@
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler: # 'one_cycle' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134

275
README.md
View File

@@ -1,152 +1,177 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<div align="center">
<a href="https://github.com/axolotl-ai-cloud/axolotl">
<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/docs/logo.png" alt="Axolotl Logo" width="250" style="margin-bottom: 20px;"/>
</a>
<h1><span style="color: #4CAF50;">Axolotl: Fine-tune LLMs with Unprecedented Ease & Power!</span> 🚀</h1>
<p style="font-size: 1.1em; color: #555;">Your ultimate toolkit for efficient, scalable, and versatile large language model fine-tuning.</p>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
<p>
<a href="https://discord.gg/HhrNrHJPRb" target="_blank">
<img src="https://img.shields.io/discord/1070542385153273887?label=Discord&logo=discord&logoColor=white&color=7289DA" alt="Discord Community" style="margin: 5px;">
</a>
<a href="https://docs.axolotl.ai/" target="_blank">
<img src="https://img.shields.io/badge/Documentation-blue?style=flat&logo=readthedocs&logoColor=white" alt="Official Documentation" style="margin: 5px;">
</a>
<a href="https://pypi.org/project/axolotl/" target="_blank">
<img src="https://img.shields.io/pypi/v/axolotl?label=PyPI&logo=pypi&logoColor=white&color=blue" alt="PyPI Package" style="margin: 5px;">
</a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases" target="_blank">
<img src="https://img.shields.io/github/downloads/axolotl-ai-cloud/axolotl/total?label=Downloads&color=green" alt="GitHub Downloads" style="margin: 5px;">
</a>
</p>
<br>
</div>
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
---
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
<div style="background-color: #f0f8ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #d0e8ff;">
<h2 style="color: #0056b3; text-align: center; margin-top: 0;">🎉 Latest Innovations & Updates!</h2>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/06:</span> Magistral with mistral-common tokenizer support!</strong> Dive into <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral" style="color: #007bff; text-decoration: none;">examples</a> to train your own Magistral models.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/05:</span> Quantization Aware Training (QAT) support!</strong> Explore the <a href="https://docs.axolotl.ai/docs/qat.html" style="color: #007bff; text-decoration: none;">docs</a> to learn more.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/04:</span> Llama 4 support!</strong> See <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4" style="color: #007bff; text-decoration: none;">examples</a> to train Llama 4 with Axolotl's linearized version!</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/03:</span> Sequence Parallelism (SP) support!</strong> Scale your context length. Read the <a href="https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl" style="color: #007bff; text-decoration: none;">blog</a> and <a href="https://docs.axolotl.ai/docs/sequence_parallelism.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/03:</span> (Beta) Fine-tuning Multimodal models!</strong> Check out the <a href="https://docs.axolotl.ai/docs/multimodal.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/02:</span> LoRA optimizations!</strong> Reduce memory and improve speed. Jump into the <a href="https://docs.axolotl.ai/docs/lora_optims.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
<li style="margin-bottom: 10px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/02:</span> GRPO support!</strong> Dive into our <a href="https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm" style="color: #007bff; text-decoration: none;">blog</a> and <a href="https://github.com/axolotl-ai-cloud/grpo_code" style="color: #007bff; text-decoration: none;">GRPO example</a>.</li>
<li style="margin-bottom: 0px; border-left: 4px solid #6495ED; padding-left: 10px;"><strong><span style="color: #2E8B57;">2025/01:</span> Reward Modelling / Process Reward Modelling fine-tuning!</strong> See <a href="https://docs.axolotl.ai/docs/reward_modelling.html" style="color: #007bff; text-decoration: none;">docs</a>.</li>
</ul>
</div>
Features:
<h2 style="color: #FF5733;"><span style="margin-right: 10px;">✨</span> Axolotl Overview: Your LLM Fine-tuning Powerhouse!</h2>
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
<div style="background-color: #fffacd; padding: 20px; border-radius: 10px; margin-bottom: 30px; border: 1px solid #ffd700;">
<p style="font-size: 1.1em; color: #333; text-align: center;">Axolotl is a powerful, flexible, and user-friendly tool designed to supercharge your post-training workflows for a wide range of cutting-edge AI models.</p>
</div>
## 🚀 Quick Start
<div style="display: flex; flex-wrap: wrap; justify-content: space-around; gap: 20px; margin-bottom: 40px;">
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">🤖</span> Broad Model Compatibility</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Train a vast array of models including LLaMA, Mistral, Mixtral, Pythia, and many more.</li>
<li>Fully compatible with HuggingFace transformers causal language models, ensuring wide adoption.</li>
</ul>
</div>
**Requirements**:
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">🔧</span> Diverse Training Methodologies</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Full fine-tuning, LoRA, QLoRA, GPTQ, QAT.</li>
<li>Preference Tuning: DPO, IPO, KTO, ORPO.</li>
<li>Advanced RL: GRPO.</li>
<li>Multimodal and Reward Modelling (RM) / Process Reward Modelling (PRM).</li>
</ul>
</div>
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.4.1
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">⚙️</span> Streamlined Configuration</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Utilize a single, intuitive YAML file across dataset preprocess, training, evaluation, quantization, and inference.</li>
</ul>
</div>
### Installation
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">⚡</span> Cutting-Edge Performance Optimizations</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li><a href="https://docs.axolotl.ai/docs/multipack.html" style="color: #007bff;">Multipacking</a>, <a href="https://github.com/Dao-AILab/flash-attention" style="color: #007bff;">Flash Attention</a>, <a href="https://github.com/facebookresearch/xformers" style="color: #007bff;">Xformers</a>, <a href="https://pytorch.org/blog/flexattention/" style="color: #007bff;">Flex Attention</a>, <a href="https://github.com/linkedin/Liger-Kernel" style="color: #007bff;">Liger Kernel</a>, <a href="https://github.com/apple/ml-cross-entropy/tree/main" style="color: #007bff;">Cut Cross Entropy</a>.</li>
<li>Sequence Parallelism (SP), LoRA optimizations.</li>
<li>Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!</li>
</ul>
</div>
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">📂</span> Flexible Data Handling</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Load datasets from local paths, HuggingFace Hub, and major cloud providers (S3, Azure, GCP, OCI).</li>
</ul>
</div>
<div style="flex: 1 1 45%; background-color: #f9f9f9; padding: 20px; border-radius: 10px; border: 1px solid #eee; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
<h3 style="color: #4CAF50; margin-top: 0;"><span style="margin-right: 5px;">☁️</span> Cloud-Ready & Deployable</h3>
<ul style="list-style-type: disc; padding-left: 20px;">
<li>Official <a href="https://hub.docker.com/u/axolotlai" style="color: #007bff;">Docker images</a> and <a href="https://pypi.org/project/axolotl/" style="color: #007bff;">PyPI packages</a> for seamless integration on cloud platforms and local hardware.</li>
</ul>
</div>
</div>
<h2 style="color: #007bff;"><span style="margin-right: 10px;">🚀</span> Quick Start: Get Fine-tuning in Minutes!</h2>
<div style="background-color: #e6f7ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #cceeff;">
<h3 style="color: #0056b3; margin-top: 0;">Requirements:</h3>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ NVIDIA GPU</span> (Ampere or newer for `bf16` and Flash Attention) or AMD GPU</li>
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ Python 3.11</span></li>
<li style="margin-bottom: 5px;"><span style="color: #333; font-weight: bold;">▶ PyTorch ≥2.5.1</span></li>
</ul>
<h3 style="color: #0056b3;">Installation:</h3>
<pre><code style="background-color: #eef; padding: 15px; border-radius: 8px; display: block; overflow-x: auto;">pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
axolotl fetch deepspeed_configs # OPTIONAL</code></pre>
<p style="font-size: 0.9em; color: #555;">Other installation approaches are described <a href="https://docs.axolotl.ai/docs/installation.html" style="color: #007bff; text-decoration: none;">here</a>.</p>
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
### Your First Fine-tune
```bash
# Fetch axolotl examples
<h3 style="color: #0056b3;">Your First Fine-tune:</h3>
<pre><code style="background-color: #eef; padding: 15px; border-radius: 8px; display: block; overflow-x: auto;"># Fetch axolotl examples
axolotl fetch examples
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
axolotl train examples/llama-3/lora-1b.yml
```
axolotl train examples/llama-3/lora-1b.yml</code></pre>
<p style="text-align: center; font-size: 1.1em; font-weight: bold; margin-top: 20px;">
That's it! Check out our <a href="https://docs.axolotl.ai/docs/getting-started.html" style="background-color: #28a745; color: white; padding: 12px 25px; border-radius: 8px; text-decoration: none; display: inline-block; transition: background-color 0.3s ease;"> Getting Started Guide ➜</a> for a more detailed walkthrough.
</p>
</div>
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
<h2 style="color: #8A2BE2;"><span style="margin-right: 10px;">📚</span> Comprehensive Documentation: Unlock Axolotl's Full Potential</h2>
## ✨ Key Features
<div style="background-color: #f7f0ff; padding: 25px; border-radius: 12px; margin-bottom: 30px; border: 1px solid #e0caff;">
<p style="text-align: center; font-size: 1.1em; color: #333;">Dive deep into Axolotl's capabilities with our extensive documentation:</p>
<ul style="list-style-type: none; padding-left: 0; text-align: center;">
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/installation.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Installation Options</a> - Detailed setup instructions for different environments</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/config.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Configuration Guide</a> - Full configuration options and examples</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/dataset_loading.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Dataset Loading</a> - Loading datasets from various sources</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/dataset-formats/" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Dataset Guide</a> - Supported formats and how to use them</li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multi-gpu.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multi-GPU Training</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multi-node.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multi-Node Training</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/multipack.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> Multipacking</a></li>
<li style="margin-bottom: 10px;"><a href="https://docs.axolotl.ai/docs/api/" style="color: #5d2b99; text-decoration: none; font-weight: bold;"> API Reference</a> - Auto-generated code documentation</li>
<li style="margin-bottom: 0px;"><a href="https://docs.axolotl.ai/docs/faq.html" style="color: #5d2b99; text-decoration: none; font-weight: bold;">❓ FAQ</a> - Frequently asked questions</li>
</ul>
</div>
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
<h2 style="color: #FF8C00;"><span style="margin-right: 10px;">🤝</span> Need Help? We're Here for You!</h2>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #7289DA;"></span> Join our vibrant <a href="https://discord.gg/HhrNrHJPRb" style="color: #7289DA; text-decoration: none; font-weight: bold;">Discord community</a> for real-time support and discussions.</li>
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #555;"></span> Explore our <a href="https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/" style="color: #FF8C00; text-decoration: none; font-weight: bold;">Examples</a> directory for practical use cases.</li>
<li style="margin-bottom: 10px;"><span style="font-size: 1.2em; color: #555;"></span> Read our <a href="https://docs.axolotl.ai/docs/debugging.html" style="color: #FF8C00; text-decoration: none; font-weight: bold;">Debugging Guide</a> for troubleshooting tips.</li>
<li style="margin-bottom: 0px;"><span style="font-size: 1.2em; color: #007bff;">✉</span> Need dedicated support? Please contact <a href="mailto:wing@axolotl.ai" style="color: #007bff; text-decoration: none; font-weight: bold;">wing@axolotl.ai</a> for professional assistance options.</li>
</ul>
## 📚 Documentation
<h2 style="color: #FF1493;"><span style="margin-right: 10px;">🌟</span> Contribute to Axolotl!</h2>
<p style="font-size: 1.1em;">
Contributions are always welcome and highly appreciated! Axolotl thrives on community support. Please see our <a href="https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md" style="color: #FF1493; text-decoration: none; font-weight: bold;">Contributing Guide</a> for details on how you can help make Axolotl even better.
</p>
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [Multipacking](https://docs.axolotl.ai/docs/multipack.html)
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
<div align="center" style="margin-top: 40px; padding: 25px; background-color: #f8f8f8; border-radius: 12px; border: 1px solid #eee;">
<h2 style="color: #FF69B4; margin-bottom: 20px;">❤️ Our Esteemed Sponsors</h2>
<p style="font-size: 1.1em; color: #555;">A huge thank you to our visionary sponsors who provide the essential resources to keep Axolotl at the forefront of LLM fine-tuning:</p>
<a href="https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl" target="_blank" style="display: inline-block; margin: 20px;">
<img src="https://assets-global.website-files.com/6247c4c1d68352614b7e87ae/63b27b3b44b82d02c8163f4f_logo-dark-square.png" alt="Modal Logo" width="180" style="vertical-align: middle; border-radius: 8px; box-shadow: 0 4px 10px rgba(0,0,0,0.15);"/>
</a>
<p style="font-size: 0.9em; color: #777; margin-top: 20px;">
<strong>Modal:</strong> Revolutionizing cloud computing for Gen AI. Run jobs, deploy models, and fine-tune LLMs at scale with ease.
</p>
<p style="font-size: 1em; color: #555; margin-top: 30px;">
Interested in powering the future of Axolotl? <span style="font-weight: bold; color: #FF69B4;">Become a sponsor!</span> Contact us at <a href="mailto:wing@axolotl.ai" style="color: #007bff; text-decoration: none;">wing@axolotl.ai</a>
</p>
</div>
## 🤝 Getting Help
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://docs.axolotl.ai/docs/debugging.html)
- Need dedicated support? Please contact [wing@axolotl.ai](mailto:wing@axolotl.ai) for options
## 🌟 Contributing
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## Supported Models
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## ❤️ Sponsors
Thank you to our sponsors who help make Axolotl possible:
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
<h2 style="color: #6A5ACD;"><span style="margin-right: 10px;">📜</span> License</h2>
<p style="font-size: 1.1em;">
This project is proudly licensed under the <span style="font-weight: bold; color: #6A5ACD;">Apache 2.0 License</span>. See the <a href="LICENSE" style="color: #007bff; text-decoration: none;">LICENSE</a> file for full details.
</p>

View File

@@ -17,7 +17,9 @@ quartodoc:
- convert
- prompt_tokenizers
- logging_config
- core.trainer_builder
- core.builders.base
- core.builders.causal
- core.builders.rl
- core.training_args
- core.chat.messages
- core.chat.format.chatml
@@ -43,6 +45,7 @@ quartodoc:
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- title: Trainers
desc: Training implementations
contents:
@@ -54,6 +57,15 @@ quartodoc:
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
- core.trainers.utils
- title: Model Loading
desc: Functionality for loading and patching models, tokenizers, etc.
contents:
- loaders.model
- loaders.tokenizer
- loaders.processor
- loaders.adapter
- loaders.patch_manager
- loaders.constants
- title: Mixins
desc: Mixin classes for augmenting trainers
contents:
@@ -117,17 +129,16 @@ quartodoc:
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
- monkeypatch.gradient_checkpointing.offload_disk
- title: Utils
desc: Utility functions
contents:
- utils.models
- utils.tokenization
- utils.chat_templates
- utils.lora
- utils.lora_embeddings
- utils.model_shard_quant
- utils.bench
- utils.freeze
@@ -138,8 +149,7 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- utils.quantization
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
@@ -189,12 +199,14 @@ quartodoc:
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
- utils.callbacks.qat
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
favicon: favicon.jpg
google-analytics: "G-9KYCVJBNMQ"
navbar:
logo: image/axolotl_logo_digital_white.svg
title: false
@@ -247,6 +259,8 @@ website:
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- section: "Core Concepts"
contents:

52
cicd/Dockerfile-uv.jinja Normal file
View File

@@ -0,0 +1,52 @@
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
@@ -55,7 +55,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
GPU_CONFIG = f"H100:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):

View File

@@ -8,8 +8,9 @@ import tempfile
import jinja2
import modal
import modal.experimental
from jinja2 import select_autoescape
from modal import App, Image
from modal import App
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -17,14 +18,15 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
@@ -38,11 +40,11 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
cicd_image = modal.experimental.raw_dockerfile_image(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
# context_mount=None,
force_build=True,
gpu="A10G",
# gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
@@ -55,7 +57,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):

View File

@@ -0,0 +1,31 @@
{
"compile": {
"disable": false,
"backend": "inductor"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

40
docker/Dockerfile-uv-base Normal file
View File

@@ -0,0 +1,40 @@
ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
ENV UV_TORCH_BACKEND="cu${CUDA}"
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& git lfs install --skip-repo \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:${PATH}"
RUN uv python install ${PYTHON_VERSION}
WORKDIR /workspace
RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
### quantize
Quantizes a model using the quantization configuration specified in your YAML file.
```bash
axolotl quantize config.yml
```
See [Quantization](./quantize.qmd) for more details.
## Legacy CLI Usage

View File

@@ -27,6 +27,8 @@ trust_remote_code:
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
tokenizer_use_mistral_common:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
@@ -65,6 +67,20 @@ bnb_config_kwargs:
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# quantization aware training
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
# post-training quantization
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -98,8 +114,10 @@ plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
# HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -157,6 +175,10 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the tools (default: "tools")
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
field_tools: tools
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
@@ -221,7 +243,7 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
Deduplicates datasets and test_datasets with identical entries.
# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
@@ -270,10 +292,25 @@ trl:
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
temperature: # Optional[float]. Sampling temperature for the GRPO policy.
top_p: # Optional[float]. Top-p sampling probability for the generation policy.
top_k: # Optional[int]. Top-k sampling for the generation policy.
min_p: # Optional[float]. Minimum probability for the generation policy.
repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
# reward modelling: `True` or `False`
@@ -483,6 +520,7 @@ output_dir: ./completed-model
# setting to `auto` will enable torch compile when torch>=2.5.1
torch_compile: # Optional[Union[Literal["auto"], bool]]
torch_compile_backend: # Optional[str]
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
# Training hyperparameters
@@ -529,7 +567,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
# Save model as safetensors (require safetensors package). Default True
save_safetensors:
# Whether to mask out or include the human's prompt from the training labels
@@ -551,7 +589,24 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
# Valid values are driven by the Transformers SchedulerType class, see:
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
# Valid values include
# - 'linear'
# - 'cosine' (default)
# - 'cosine_with_restarts'
# - 'polynomial'
# - 'constant'
# - 'constant_with_warmup'
# - 'inverse_sqrt'
# - 'reduce_lr_on_plateau'
# - 'cosine_with_min_lr'
# - 'warmup_stable_decay'
# Additional schedulers include:
# - 'one_cycle'
# - 'rex'
lr_scheduler:
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -569,7 +624,7 @@ lr_div_factor: # Learning rate div factor
#
# Valid values for 'optimizer' include:
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_fused (default)
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused

View File

@@ -52,7 +52,9 @@ We recommend checking the below examples for other usecases.
### Examples
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
#### Training on last message
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
@@ -66,7 +68,9 @@ datasets:
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
#### Overriding default chat template
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
@@ -76,7 +80,13 @@ datasets:
roles_to_train: ["assistant"] # default value
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
::: {.callout-note}
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
:::
#### Using default chat template with fallback
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
@@ -85,7 +95,9 @@ datasets:
type: chat_template
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
#### Custom Jinja template
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
@@ -100,7 +112,9 @@ datasets:
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
#### Using template with different token for EOT and EOS
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
```yaml
eot_tokens:
@@ -125,7 +139,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
:::
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
@@ -145,7 +159,73 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
:::
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
#### Using tool use
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
```json
{
"tools": [
{
"type": "...",
"function": {
"name": "...",
"description": "...",
"parameters": {
"type": "...",
"properties": {
// ...
},
"required": ["..."],
},
},
},
],
"messages": [
// ...
{
"role": "assistant", // call the function via assistant
"tool_calls": [
{
"type": "function",
"function": {
"name": "...",
"arguments": {
"...": "...",
}
}
}
]
},
{
"role": "tool",
"name": "...",
"content": "..."
},
],
}
```
::: {.callout-note}
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
```yaml
chat_template: llama4
datasets:
- path: ...
type: chat_template
# field_tools: tools # default is `tools`
```
::: {.callout-tip}
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
:::
#### Using fine-grained control over token masking
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -196,7 +276,9 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:

View File

@@ -36,10 +36,6 @@ It is typically recommended to save your dataset as `.jsonl` due to its flexibil
Axolotl supports loading from a Hugging Face hub repo or from local files.
::: {.callout-important}
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
:::
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
@@ -77,18 +73,21 @@ datasets:
type: completion
```
From local files (either example works):
From local files:
```yaml
datasets:
- path: A.jsonl
type: completion
- path: json
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
- path: B.jsonl
type: completion
```
::: {.callout-important}
For `completion` only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
:::
### Pre-training dataset configuration tips
#### Setting max_steps

View File

@@ -54,7 +54,7 @@ datasets:
#### Files
Usually, to load a JSON file, you would do something like this:
To load a JSON file, you would do something like this:
```python
from datasets import load_dataset
@@ -66,20 +66,12 @@ Which translates to the following config:
```yaml
datasets:
- path: json
data_files: /path/to/your/file.jsonl
```
However, to make things easier, we have added a few shortcuts for loading local dataset files.
You can just point the `path` to the file or directory along with the `ds_type` to load the dataset. The below example shows for a JSON file:
```yaml
datasets:
- path: /path/to/your/file.jsonl
- path: data.json
ds_type: json
```
In the example above, it can be seen that we can just point the `path` to the file or directory along with the `ds_type` to load the dataset.
This works for CSV, JSON, Parquet, and Arrow files.
::: {.callout-tip}

View File

@@ -9,7 +9,7 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with Pytorch 2.7.0 and CUDA 12.8.
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
:::
## Base
@@ -32,11 +32,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
@@ -77,12 +76,10 @@ Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
- `0.9.2`
## Cloud

View File

@@ -110,3 +110,17 @@ description: Frequently asked questions
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
**Q: `Data processing error: CAS service error`**
> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**
> A: Depending on the version of torch, you may need to include this in your YAML:
> ```yaml
> flex_attn_compile_kwargs:
> dynamic: false
> mode: max-autotune-no-cudagraphs
> ```

View File

@@ -180,7 +180,7 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
- PyTorch ≥2.5.1
## Installation Methods {#sec-installation-methods}
@@ -41,6 +41,40 @@ installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
Install uv if not already installed
```{.bash}
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
then create the venv and activate
```{.bash}
export UV_TORCH_BACKEND=cu126
uv venv --no-project --relocatable
source .venv/bin/activate
```
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
```
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:

View File

@@ -84,6 +84,10 @@ lora_qkv_kernel: true
lora_o_kernel: true
```
::: {.callout-note}
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -43,7 +43,7 @@ datasets:
# leave the vision model and vision tower frozen
# load_in_8bit: true
adapter: lora
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# (optional) if you want to resize images to a set size
image_size: 512

32
docs/qat.qmd Normal file
View File

@@ -0,0 +1,32 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of models which are quantized
by applying "fake" quantizations to the model's weights (and optionally, activations) during training. This fake
quantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually
quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
support for QAT and post-training quantization (PTQ) in axolotl.
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.
## Configuring QAT in Axolotl
To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.

53
docs/quantize.qmd Normal file
View File

@@ -0,0 +1,53 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).
::: {.callout-note}
We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
:::
## Configuring Quantization in Axolotl
Quantization is configured using the `quantization` key in your configuration file.
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
output_dir: # The path to the output directory.
```
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
you used to train the model:
```yaml
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
```bash
axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.

View File

@@ -16,7 +16,8 @@ feedback. Various methods include, but not limited to:
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
## RLHF using Axolotl
@@ -499,7 +500,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
:::
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
@@ -582,7 +583,20 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
```yaml
trl:
loss_type: dr_grpo
# Normalizes loss based on max completion length (default: 256)
max_completion_length:
```
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### SimPO

View File

@@ -28,7 +28,7 @@ pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -30,7 +30,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -29,7 +29,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -0,0 +1,79 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./outputs/qat_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -5,6 +5,10 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
load_in_8bit: true
load_in_4bit: false

View File

@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
@@ -38,6 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -0,0 +1,71 @@
# Finetune Magistral Small with Axolotl
Magistral Small is a 24B parameter opensource model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
```
2. Download the example config:
```bash
axolotl fetch examples
```
3. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
```
This config uses about 24GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,72 @@
base_model: mistralai/Magistral-Small-2506
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true

View File

@@ -0,0 +1,63 @@
base_model: mistralai/Magistral-Small-2506
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -27,7 +27,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -0,0 +1,78 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out/
sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 256
fake_quant_after_n_steps: 1000
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.5 KiB

After

Width:  |  Height:  |  Size: 4.7 KiB

View File

@@ -6,21 +6,20 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.9
liger-kernel==0.5.10
# END section
packaging==23.2
huggingface_hub==0.31.0
huggingface_hub==0.32.2
peft==0.15.2
transformers==4.51.3
transformers==4.52.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0
hqq==0.2.5
accelerate==1.7.0
datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.1
hf_xet==1.1.2
optimum==1.16.2
hf_transfer
@@ -63,8 +62,10 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.9.0
torchao==0.10.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.0

View File

@@ -9,6 +9,8 @@ except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
USE_UV = "--uv" in sys.argv[1:]
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
@@ -23,7 +25,9 @@ if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
)

View File

@@ -11,7 +11,7 @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
```
cd /workspace

View File

@@ -1,11 +1,15 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
@@ -31,6 +35,7 @@ elif v < V("2.6.0"):
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
)

View File

@@ -118,7 +118,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed==0.17.0",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.10.0.dev0"
__version__ = "0.10.0"

View File

@@ -28,7 +28,6 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
@@ -89,6 +88,26 @@ class VllmServeCliArgs:
},
)
enable_reasoning: Optional[bool] = field(
default=None,
)
reasoning_parser: Optional[str] = field(
default=None,
)
@dataclass
class QuantizeCliArgs:
"""Dataclass with CLI arguments for `axolotl quantize` command."""
base_model: Optional[str] = field(default=None)
weight_dtype: Optional[str] = field(default=None)
activation_dtype: Optional[str] = field(default=None)
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
@dataclass
class EvaluateCliArgs:

View File

@@ -1,6 +1,5 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_accelerate_default_config() -> None:

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"

View File

@@ -1,7 +1,6 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
@@ -22,11 +21,12 @@ from axolotl.utils.config import (
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
LOG.info("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
LOG.info(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
LOG.info("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
LOG.info("Invalid input. Please enter a number.")
return chosen_file

View File

@@ -1,6 +1,5 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

View File

@@ -1,7 +1,6 @@
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def get_multi_line_input() -> str:

View File

@@ -2,7 +2,6 @@
# pylint: disable=redefined-outer-name
import logging
import os
import subprocess # nosec B404
import tempfile
@@ -17,6 +16,7 @@ import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
@@ -30,8 +30,11 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -176,7 +179,7 @@ def train(
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
@@ -333,6 +336,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
from axolotl.cli.quantize import do_quantize
do_quantize(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))

View File

@@ -1,20 +1,18 @@
"""CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_merge_lora(*, cfg: DictDefault) -> None:
@@ -68,12 +66,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
# pylint: disable=duplicate-code
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,

View File

@@ -1,7 +1,6 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json
import logging
import os
import shutil
from pathlib import Path
@@ -11,7 +10,6 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
import transformers
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -24,11 +22,11 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
@@ -197,11 +195,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -1,6 +1,5 @@
"""CLI to run preprocessing of a dataset."""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

View File

@@ -0,0 +1,90 @@
"""
CLI to post-training quantize a model using torchao
"""
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = get_logger(__name__)
def do_quantize(
config: Union[Path, str],
cli_args: dict,
):
"""
Quantizes a model's model's weights
Args:
config (Union[Path, str]): The path to the config file
cli_args (dict): Additional command-line arguments
"""
print_axolotl_text_art()
cfg = load_cfg(config)
if cfg.qat and cfg.quantization:
raise ValueError(
"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
)
if cfg.qat:
quantize_cfg = cfg.qat
elif cfg.quantization:
quantize_cfg = cfg.quantization
else:
raise ValueError(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
quantize_embedding = (
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
f"\tweight_dtype: {weight_dtype}\n"
f"\tactivation_dtype: {activation_dtype}\n"
f"\tgroup_size: {group_size}\n"
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
from typing import Union
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""

View File

@@ -4,7 +4,6 @@ import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -23,8 +22,9 @@ from transformers import (
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):

View File

@@ -2,14 +2,27 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@dataclass
class AxolotlScriptArguments(ScriptArguments):
"""
Additional arguments for the VLLM server
"""
reasoning_parser: str = field(default="", kw_only=True)
enable_reasoning: bool | None = field(default=None, kw_only=True)
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
@@ -24,6 +37,7 @@ def do_vllm_serve(
Returns:
process_id: the process id of the started VLLM server
"""
patch_vllm_worker()
cfg = load_cfg(config)
model = cfg.base_model
@@ -43,9 +57,16 @@ def do_vllm_serve(
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
reasoning_parser = (
cli_args.get("reasoning_parser") or cfg.vllm.reasoning_parser or ""
)
enable_reasoning = (
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
vllm_script_args = ScriptArguments(
model,
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
@@ -53,5 +74,67 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker

View File

@@ -1,5 +1,3 @@
"""
Various shared constants
"""
"""Various shared constants"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -1,23 +1,21 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
@dataclass
@@ -30,16 +28,7 @@ class TrainDatasetMeta:
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
@@ -51,44 +40,37 @@ def load_datasets(
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
"""Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
debug: Whether to print out tokenization of sample. This is duplicated in
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
if (
cfg.debug
or getattr(cli_args, "debug", False)
or getattr(cli_args, "debug_text_only", False)
or getattr(cli_args, "debug_num_examples", 0) > 0
or debug
):
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
@@ -113,13 +95,10 @@ def load_datasets(
def load_preference_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
"""Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
@@ -130,23 +109,28 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
total_num_steps = None
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
if cli_args.debug or cfg.debug:
total_num_steps: int | None = None
if cfg.rl is not RLType.GRPO:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if (cli_args and cli_args.debug) or cfg.debug:
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
dataset=train_samples,
tokenizer=tokenizer,
num_examples=num_examples,
text_only=text_only,
rl_mode=True,
)

View File

@@ -0,0 +1,6 @@
"""Trainer builder classes"""
from .causal import HFCausalTrainerBuilder
from .rl import HFRLTrainerBuilder
__all__ = ["HFCausalTrainerBuilder", "HFRLTrainerBuilder"]

View File

@@ -0,0 +1,508 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for trainer builder"""
import abc
import importlib
import logging
import sys
from abc import abstractmethod
from contextlib import suppress
from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
self._train_dataset = None
self._eval_dataset = None
self._model_ref = None
self._peft_config = None
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@property
def peft_config(self):
return self._peft_config
@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config
@abstractmethod
def build(self, total_num_steps):
pass
def get_callbacks(self) -> list[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
callbacks.append(GPUStatsCallback(cfg=self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_ratio = self.cfg.warmup_ratio
elif total_num_steps:
warmup_steps = min(int(0.03 * total_num_steps), 100)
else:
warmup_ratio = 0.03
if warmup_steps == 1:
warmup_steps = 2
if self.cfg.logging_steps is not None:
training_args_kwargs["logging_steps"] = self.cfg.logging_steps
else:
training_args_kwargs["logging_steps"] = (
500 # transformers defaults to 500
if not total_num_steps
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwargs: dict):
def _configure_custom_optimizer(
training_args_kwargs: dict, trainer_kwargs: dict
):
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_args_kwargs["learning_rate"],
"weight_decay": training_args_kwargs["weight_decay"],
}
# Adam-specific kwargs
adam_kwargs: dict = {}
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
"adam_beta2"
):
adam_kwargs["betas"] = (
training_args_kwargs.get("adam_beta1"),
training_args_kwargs.get("adam_beta2"),
)
if training_args_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
beta3 = training_args_kwargs.get("adam_beta3", 0.9999)
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
# Note: This is not used in training_args_kwargs, but in trainer_kwargs
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
_configure_custom_optimizer(training_args_kwargs, trainer_kwargs)
else:
# Use transformers' optimizer
training_args_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_args_kwargs["optim_args"] = optim_args
if (
self.cfg.optimizer == "adamw_anyprecision"
and Path(self.cfg.torchdistx_path).exists()
):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
def _configure_hub_parameters(self, training_args_kwargs: dict):
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
# eval_strategy and eval_steps
if not self.eval_dataset and self.cfg.val_set_size == 0:
# do not eval if no eval_dataset and val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
training_args_kwargs["eval_on_start"] = True
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
training_args_kwargs["eval_on_start"] = True
def _configure_reporting(self, training_args_kwargs: dict):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_args_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_args_kwargs["run_name"] = None
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
def _set_base_training_args(
self, total_num_steps
) -> tuple[dict[str, Any], dict[str, Any]]:
training_args_kwargs: dict[str, Any] = {}
trainer_kwargs: dict[str, Any] = {}
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
self._configure_precision_settings(training_args_kwargs)
self._configure_save_and_eval_strategy(training_args_kwargs)
self._configure_gradient_checkpointing(training_args_kwargs)
# set arg into trainer_args_kwargs with same name if value not None
for arg in [
# optim/scheduler
"adam_beta1",
"adam_beta2",
"adam_beta3",
"adam_epsilon",
"adam_epsilon2",
"cosine_min_lr_ratio",
"cosine_constant_lr_ratio",
"optim_target_modules",
# trainer
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"embedding_lr",
"embedding_lr_scale",
"lr_groups",
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs)
return training_args_kwargs, trainer_kwargs

View File

@@ -0,0 +1,488 @@
"""Builder for causal trainers"""
import inspect
import math
import os
from pathlib import Path
from typing import Type, Union
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlMambaTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
"""
Gets the trainer class for the given configuration.
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not self.cfg.flash_attention
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs["lisa_step_interval"] = (
self.cfg.lisa_step_interval
)
training_arguments_kwargs["lisa_layers_attribute"] = (
self.cfg.lisa_layers_attribute
)
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs["neftune_noise_alpha"] = (
self.cfg.neftune_noise_alpha
)
if self.cfg.accelerator_config:
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_arguments_kwargs.update(plugin_training_args)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
multiple = 64
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters:
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
def build_collator(
self,
training_args, # type: "AxolotlTrainingArguments" # type: ignore
is_eval=False,
**kwargs,
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
use_batch_sampler_collator = False
if is_eval is False and training_args.sample_packing:
use_batch_sampler_collator = True
if is_eval and training_args.eval_sample_packing:
use_batch_sampler_collator = True
collator: Type[
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
]
]
collator_args = [self.tokenizer]
collator_cls_and_kwargs = None
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
self.cfg, is_eval=is_eval
)
if collator_cls_and_kwargs:
collator = collator_cls_and_kwargs[0]
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama
if (
self.cfg.flex_attention
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
)
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
else:
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
return collator(
*collator_args,
**kwargs,
)

View File

@@ -0,0 +1,238 @@
"""Builder for RLHF trainers"""
import inspect
from pathlib import Path
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
LOG = get_logger(__name__)
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def _get_trainer_cls(self, trainer_kwargs: dict):
"""
Returns trainer_cls and trainer_cls_args
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore
if trainer_cls is not None:
return trainer_cls, trainer_cls_args
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args.append(self.model_ref)
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):
"""
Returns training_args and trainer_kwargs
"""
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args, trainer_kwargs
def build(self, total_num_steps):
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# TODO: build PPOConfig
raise NotImplementedError("PPO trainer builder is not implemented yet.")

View File

@@ -156,7 +156,6 @@ class Messages(BaseModel):
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:

File diff suppressed because it is too large Load Diff

View File

@@ -4,11 +4,10 @@
from __future__ import annotations
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Literal
from functools import partial, wraps
from typing import Callable, Literal, Optional
import datasets
import torch
@@ -26,6 +25,7 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
CheckpointSaveMixin,
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
@@ -34,12 +34,16 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -101,7 +105,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
sampler = MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -111,9 +115,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
num_processes=self.args.dataset_num_proc,
)
def _get_train_sampler(self) -> Sampler | None:
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
"""
Helper method to get the sampler for training. Handles cases for sample packing
and curriculum sampling (sequential).
@@ -137,7 +147,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
dataset=train_dataset,
)
return base_sampler
@@ -150,8 +160,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
@@ -172,125 +180,93 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
return base_sampler
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
def _get_dataloader(
self,
dataset: Dataset,
description: str,
batch_size: int,
sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
is_training: bool = False,
dataloader_key: Optional[str] = None,
) -> DataLoader:
"""Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
params = {
data_collator = self.data_collator if is_training else self.eval_data_collator
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if isinstance(dataset, datasets.Dataset):
if is_training:
if not self.args.sample_packing or self.args.pretraining:
dataset = self._remove_unused_columns(
dataset, description="training"
)
elif (
not is_training
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
batch_size = (
batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
else:
dataset = self._remove_unused_columns(dataset, description=description)
else:
data_collator = self._get_collator_with_removed_columns(
self.data_collator, description=description
)
dataloader_params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
dataloader_params["drop_last"] = get_not_null(
self.args.dataloader_drop_last, True
)
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
del dataloader_params["drop_last"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if is_training:
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(dataloader)
dataloader = DataLoader(dataset, **dataloader_params)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version for eval dataloaders.
# fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
else:
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
# fmt: on
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
return dataloader
return super().get_eval_dataloader(eval_dataset)
return self.accelerator.prepare(dataloader)
def _get_bench_sampler(
self, bench_dataset: Dataset

View File

@@ -22,10 +22,19 @@ class DPOStrategy:
training_args_kwargs = {}
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
return training_args_kwargs

View File

@@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
dpo_norm_loss: bool | None = False

View File

@@ -5,65 +5,31 @@ from functools import wraps
from typing import Any, Dict, Union
import torch
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
):
"""Extend the base DPOTrainer for axolotl helpers."""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
# pylint: disable=duplicate-code
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
# pylint: disable=duplicate-code
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
@@ -117,3 +83,20 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
gc.collect()
torch.cuda.empty_cache()
return loss
def concatenated_forward(
self,
model: nn.Module,
batch: dict[str, Union[list, torch.LongTensor]],
is_ref_model: bool = False,
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -2,7 +2,6 @@
import importlib
import inspect
import logging
from typing import Any
from trl.trainer.grpo_trainer import RewardFunc
@@ -13,9 +12,10 @@ from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class GRPOStrategy:
@@ -69,6 +69,9 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -106,7 +109,9 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
def set_trainer_args(
cls, cfg: DictDefault
) -> list[Any]: # pylint: disable=unused-argument
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
@@ -123,6 +128,7 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
return trainer_kwargs
@classmethod
@@ -132,7 +138,7 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc"]
return ["dataset_num_proc", "max_length"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
@@ -167,4 +173,4 @@ class GRPOStrategy:
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func
return reward_func_fqn

View File

@@ -12,3 +12,5 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None

View File

@@ -3,6 +3,7 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
import datasets
@@ -43,6 +44,7 @@ from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available():
@@ -50,11 +52,49 @@ if is_peft_available():
from peft import PeftConfig
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
@@ -77,6 +117,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: "PeftConfig | None" = None,
optimizer_cls_and_kwargs: tuple[type, dict] | None = None,
):
# First call the superclass constructor with all arguments
super().__init__(
@@ -90,6 +131,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
)
# Get number of SP groups (number of processes divided by SP degree)
@@ -131,6 +173,13 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
f"the valid values for the number of generations are: {possible_values}."
)
self.sp_group = None
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.local_rank = 0
self.local_world_size = 1
def train(self, *args, **kwargs):
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
@@ -138,6 +187,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
return super().train(*args, **kwargs)
def _get_train_sampler(self) -> Sampler:
effective_batch_size = (
self.args.per_device_train_batch_size

View File

@@ -3,6 +3,7 @@
# pylint: disable=unused-import
# flake8: noqa
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin

View File

@@ -0,0 +1,21 @@
"""Custom handling to not fail training if fsdp optimizer is not savable"""
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CheckpointSaveMixin(Trainer):
"""Mixin to handle saving the optimizer and scheduler if they are not savable."""
def _save_optimizer_and_scheduler(self, output_dir):
try:
super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc:
LOG.warning(
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible."
)

View File

@@ -1,18 +1,17 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class OptimizerMixin(Trainer):
@@ -199,3 +198,20 @@ class OptimizerMixin(Trainer):
)
return self.optimizer
class OptimizerInitMixin:
"""
Mixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not
accept optimizer_cls_and_kwargs as kwarg in constructor.
"""
def __init__(self, *args, **kwargs):
optimizer_cls_and_kwargs = kwargs.pop("optimizer_cls_and_kwargs", None)
super().__init__(*args, **kwargs)
if (
optimizer_cls_and_kwargs
and self.optimizer_cls_and_kwargs is None
and self.optimizer is None
):
self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs

View File

@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""
import logging
import os
import random
@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class RngLoaderMixin(Trainer):

View File

@@ -1,12 +1,11 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class SchedulerMixin(Trainer):
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
min_lr=0 if not use_cosine_min_lr else (
self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
LOG.warning(
"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler # type: ignore

View File

@@ -1,7 +1,5 @@
"""Module for TRL PPO trainer"""
from typing import Literal, Union
import torch
from tqdm import tqdm
from trl import (
@@ -14,6 +12,7 @@ from trl import (
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -75,87 +74,19 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
chosen_rewards
).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
rejected_rewards
).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
reward_accuracies
).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
)
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
policy_rejected_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
)
metrics[f"{prefix}log_odds_ratio"] = (
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
)
metrics[f"{prefix}log_odds_chosen"] = (
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
)
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -163,89 +94,19 @@ class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = (
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
)
metrics[f"{prefix}rewards/rejected"] = (
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
)
metrics[f"{prefix}rewards/accuracies"] = (
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
)
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
.mean()
.item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
)
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
):
"""
Extend the base RewardTrainer for axolotl helpers
"""
@@ -253,7 +114,9 @@ class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

View File

@@ -2,244 +2,17 @@
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Type
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
AxolotlTrainingMixins: Type = merge_training_args()
@dataclass

View File

@@ -0,0 +1,224 @@
"""
Base Axolotl Training Mixins shared across various trainer configs
"""
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "The number of processes to use for data processing"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
# kd_ce_alpha: Optional[float] = field(
# default=None,
# metadata={
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
# },
# )
#
# kd_alpha: Optional[float] = field(
# default=1.0,
# metadata={"help": "The alpha scaling parameter for KD loss"},
# )
#
# kd_temperature: Optional[float] = field(
# default=1.0,
# metadata={
# "help": "the temperature parameter for KL divergence loss when using KD"
# },
# )
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section

View File

@@ -1,12 +1,12 @@
"""Module containing Dataset functionality"""
import logging
import os
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
@@ -15,25 +15,25 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
"""Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer: The prompt tokenizing method for processing the data.
dataset: Dataset with text files.
process_count: Number of processes to use for tokenizing.
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
@@ -75,14 +82,14 @@ class TokenizedPromptDataset(Dataset):
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
dataset: Dataset | IterableDataset,
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
features = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
@@ -93,12 +100,13 @@ def wrap_dataset_for_tokenized_prompt(
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
@@ -109,7 +117,7 @@ class ConstantLengthDataset(IterableDataset):
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
@@ -173,7 +181,10 @@ class ConstantLengthDataset(IterableDataset):
}
else:
LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
@@ -17,6 +16,7 @@ from axolotl.train import (
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

View File

@@ -22,7 +22,7 @@ from __future__ import annotations
import collections
import importlib
import logging
import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -31,6 +31,9 @@ from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__, use_environ=True)
if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
@@ -39,31 +42,39 @@ if TYPE_CHECKING:
class BasePlugin:
"""Base class for all plugins. Defines the interface for plugin methods.
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but
A plugin is a reusable, modular, and self-contained piece of code that extends
the functionality of Axolotl. Plugins can be used to integrate third-party models,
modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and
implement the required methods.
Note:
Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is loaded, but
before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded,
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded,
inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is
- post_trainer_create(cfg, trainer): Performs actions after the trainer is
created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
- add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
training.
"""
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg): # pylint: disable=unused-argument
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
Args:
@@ -73,6 +84,11 @@ class BasePlugin:
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
def get_training_args_mixin(self) -> str | None:
"""
Returns a dataclass model for the plugin's training arguments.
"""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -148,6 +164,31 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
"""
Returns custom training arguments to set on TrainingArgs.
Args:
cfg: The global axolotl configuration.
Returns:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
Args:
cfg: The global axolotl configuration.
is_eval: Whether this is an eval split.
Returns:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -268,17 +309,18 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager:
class PluginManager: # pylint: disable=too-many-public-methods
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
Attributes:
plugins: A list of loaded plugins.
Methods:
get_instance(): Static method to get the singleton instance of `PluginManager`.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
Note:
Key methods include:
- get_instance(): Static method to get the singleton instance of `PluginManager`.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
@@ -322,12 +364,15 @@ class PluginManager:
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
LOG.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError as exc:
LOG.error(f"Failed to load plugin: {plugin_name}")
# print stacktrace
traceback.print_exc()
print(f"Error: {exc}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
@@ -342,6 +387,20 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def get_training_args_mixin(self):
"""
Returns a list of dataclasses for all registered plugins' training args mixins'
Returns:
list[str]: A list of dataclsses
"""
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -431,6 +490,42 @@ class PluginManager:
return trainer_cls
return None
def get_training_args(self, cfg):
"""
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The training arguments
"""
training_args_kwargs = {}
for plugin in self.plugins.values():
training_args = plugin.get_training_args(cfg)
if training_args is not None:
training_args_kwargs.update(training_args)
return training_args_kwargs
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
"""
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
Parameters:
cfg (dict): The configuration for the plugins.
is_eval (bool): Whether this is an eval split.
Returns:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
if collator is not None:
collator_cls, collator_kwargs = collator
return collator_cls, collator_kwargs
return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.
@@ -534,7 +629,6 @@ class PluginManager:
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)

View File

@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
from typing import Any, Dict, List
from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,3 +61,43 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
def merge_training_args() -> Type:
"""
Merges training arguments from registered plugins with the base TrainingArguments.
This function retrieves the training arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlTrainingMixins,
that inherit from the base configurations and include the training arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
mixin_classes = []
dynamic_input = ""
for plugin_args in training_args_mixins:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
mixin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase

View File

@@ -24,6 +24,14 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
## Usage
**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
```bash
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
pip3 install --no-build-isolation -e .
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
LOG = get_logger(__name__, use_environ=True)
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
# The patch checks model_type internally
cce_patch(cfg.model_config_type)

View File

@@ -15,12 +15,13 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CutCrossEntropyArgs(BaseModel):

View File

@@ -15,23 +15,14 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
MLLAMA_INPUTS_DOCSTRING,
_prepare_cross_attention_mask,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -164,10 +155,6 @@ def cce_forward(
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -2,15 +2,15 @@
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
LOG = get_logger(__name__)
class GrokfastCallbackHandler(TrainerCallback):

View File

@@ -15,7 +15,12 @@
"""
Plugin init to add KD support to Axolotl.
"""
from typing import Any
from transformers import Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
@@ -28,9 +33,75 @@ class KDPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
def get_training_args_mixin(self):
return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None
def get_training_args(self, cfg):
return {
"kd_ce_alpha": cfg.kd_ce_alpha,
"kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
if not cfg.kd_trainer:
return None, None
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
use_batch_sampler_collator = False
if is_eval is False and cfg.sample_packing:
use_batch_sampler_collator = True
if cfg.eval_sample_packing and is_eval:
use_batch_sampler_collator = True
if cfg.kd_online_server_base_url:
from .collator_online_teacher import OnlineTeacherCollator
return OnlineTeacherCollator, {
"kd_online_server_base_url": cfg.kd_online_server_base_url,
"kd_online_topk": cfg.kd_online_topk,
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq, {}
return DataCollatorForKD, {}
def pre_model_load(self, cfg):
from .kernels.models import apply_kernel
apply_kernel(cfg.model_config_type)
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds temp scheduler callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
callback = KDTemperatureSchedulerCallback(
cfg.kd_temperature,
cfg.kd_temperature_min,
trainer,
)
return [callback]
return []

View File

@@ -15,9 +15,19 @@
"""
Plugin args for KD support.
"""
from typing import Optional
from dataclasses import dataclass
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field
class InferenceServerType(str, Enum):
"""
Online inferences server types to handle different request args
"""
vllm = "vllm" # pylint: disable=invalid-name
sglang = "sglang" # pylint: disable=invalid-name
class KDArgs(BaseModel):
@@ -25,13 +35,41 @@ class KDArgs(BaseModel):
Input args for knowledge distillation.
"""
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[float] = (
kd_trainer: float | None = None # whether to use KD trainer
kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = 0.0 # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
kd_online_server: InferenceServerType | None = Field(
default_factory=lambda: InferenceServerType.vllm
)
kd_online_timeout: int | None = 120
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd
)
@dataclass
class KDTrainingArgsMixin:
"""
Additional args for KD training.
"""
kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
)

Some files were not shown because too many files have changed in this diff Show More