Compare commits

..

52 Commits

Author SHA1 Message Date
Wing Lian
1a229b0901 add colab callback to fix inference post train 2025-05-05 16:40:01 -04:00
Wing Lian
72ece3dadf support for configurable group and bin size for sample packing 2025-05-05 06:57:21 -04:00
Wing Lian
2e74e1d289 fix xformers inference 2025-05-05 06:20:31 -04:00
Wing Lian
4f478083e7 fix batch size setter 2025-05-05 03:49:01 -04:00
Wing Lian
82453bab7e handle xformers patch for inference too 2025-05-05 03:22:02 -04:00
Wing Lian
5b2bd75aba parallel bin packing
fix error with lambda and pickling

make sure things are in float instead of np.float
2025-05-04 23:24:46 -04:00
Wing Lian
03508c6816 improve readability of multipack sampler 2025-05-04 23:24:40 -04:00
Wing Lian
48b3e14a24 Print axolotl art if train is called outside of cli: 2025-05-04 23:24:35 -04:00
Wing Lian
544b1212d8 use relative import 2025-05-04 07:36:26 -04:00
Wing Lian
695fc2f802 missing __init__ 2025-05-04 07:31:01 -04:00
Wing Lian
c7f38ba96b fix seq lens calc to drop hanging sequences 2025-05-03 21:56:45 -04:00
Wing Lian
372fd08548 fix fp16 / bf16 reset when using fp16 with bf16 auto 2025-05-03 21:56:39 -04:00
Wing Lian
52cab2aa5b refactor so we can add test 2025-05-03 21:55:34 -04:00
Wing Lian
bed8f354a5 reorder the packing check 2025-05-03 15:38:29 -04:00
Wing Lian
f301a165c3 fix xformers + packing validation 2025-05-03 15:00:33 -04:00
Wing Lian
2b3a09aeae wire up the patch 2025-05-03 15:00:29 -04:00
Wing Lian
648780de51 xformers attention with packing 2025-05-03 14:59:49 -04:00
Wing Lian
ecc2388274 chunked cross entropy loss 2025-05-03 14:59:43 -04:00
Wing Lian
ebf724a9d9 fix import 2025-05-03 12:03:15 -04:00
Wing Lian
99095573c3 add tabs back to code check 2025-05-03 12:03:15 -04:00
Wing Lian
140083a828 patch peft to not upcast everything 2025-05-03 12:03:15 -04:00
Wing Lian
37c27aedc1 fsdp embeddings should be float32 per comment 2025-05-03 12:03:15 -04:00
Wing Lian
ed922796b7 include multipack support for qwen3 family (#2622) 2025-05-03 12:02:39 -04:00
Wing Lian
3dd9c3bf3f setup hf transfer too and fix auto bf16 when fp16 enabled (#2620) [skip ci] 2025-05-03 12:02:26 -04:00
Wing Lian
0ba7d362fa qwen3 and qwen3_moe support for liger kernels (#2612)
* qwen3 and qwen3_moe support for liger kernels

* fix moe module path

* fix: qwen3 liger input args and mlp

* fix: qwen3 input args and output class

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-02 09:29:55 -04:00
aitechguy
e4f73bc98e remove keys to incoporate changes for the trl update (#2616) 2025-05-02 08:47:42 -04:00
Wing Lian
bcb59c70e2 automatically set pad_to_sequence_len when use packing (#2607)
* automatically set pad_to_sequence_len when use packing

* update tests
2025-05-01 13:24:38 -04:00
NanoCode012
6a3e6f8c53 fix: run preview-docs only when md/qmd changes (#2606)
* fix: run preview-docs only when md/qmd changes

* feat: add quarto yaml based on PR feedback
2025-05-01 13:21:28 -04:00
Wing Lian
fee3c13bb5 Logging config for colab (#2611)
* only configure logging on cli to play nicely with colab

* allow reloading the config on the fly from a dict

* make sure to use dict for yaml

* reuse existing function for load

* make cli args optional

* mps fix and respect max_steps
2025-05-01 12:58:00 -04:00
Rahul Tuli
996fc124e5 Add: Sparse Finetuning Integration with llmcompressor (#2479)
* Add: SFTPlugin with llmcompressor

* Update: review comments!

* Add:llmcompressor instalable

* pre commit hooks

* Use: warning over warn

* Revert: TODO's

* Update llmcompressor version to latest

* Apply suggestions from @markurtz

Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>

* Address review comments from @markurtz

* Add: llcompressor installable

* Rename: sft.yaml to sparse-finetuning.yaml

* Use: absolute import

* Update model config

* Move: LLMCompressorPlugin into it's own submodule

* Add: `llm_compressor` integration documentation

* Rebase and updates!

* Tests, Style, Updates

* Add: .qmd file

* Address Review Comments:
* deleted redundant docs/llm_compressor.qmd
* incorporated feedback in integration README.md
* added llmcompressor integration to docs/custom_integrations.qmd

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* Add: line about further optimizations using llmcompressor

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* Apply patch from @winglian

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* Fix: Test

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* additional fixes for docker and saving compressed

* split llmcompressor from vllm checks

* Reset session between tests

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* move decorator to test method instead of class

* make sure to reset the session after each test

* move import of llmcompressor to reset session inside test

---------

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-01 12:25:16 -04:00
Wing Lian
e963990ad7 add missing __init__ for lr monkeypatch fix (#2609) 2025-05-01 09:41:32 -04:00
Dhruv Mullick
c3f2b1c5c2 Add num_completions_to_print for trl and grpo (#2604) 2025-04-30 21:00:30 -04:00
Wing Lian
6ba5c0ed2c use latest hf-xet and don't install vllm for torch 2.7.0 (#2603)
* use latest hf-xet and don't install vllm for torch 2.7.0

* fix runpod hub tests
2025-04-30 18:27:39 -04:00
Wing Lian
24ff5f53f8 additional args for grpo config/trainer (#2598) 2025-04-30 13:11:12 -04:00
Wing Lian
5e949eaa07 replace zero_only with simpler if statement (#2592) 2025-04-30 13:11:03 -04:00
Wing Lian
89ca14d9a0 ensure we pass axolotl extras to the Dockerfile so vllm is included in shipped images (#2599) 2025-04-30 11:35:45 -04:00
Wing Lian
8446b4ad28 don't automatically enable lora kernels for RL training (#2600) 2025-04-30 11:06:50 -04:00
Wing Lian
fc79606b6d only import vllm serve cli if its being called (#2597) [skip ci] 2025-04-30 09:11:25 -04:00
Wing Lian
baeb00231b Handle other reasoning trace dataset formats (#2591)
* Handle other reasoning trace dataset formats

* rename var to improve readability

* chore: refactor with comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-04-30 03:32:55 -04:00
Wing Lian
2413688b08 upload the deepspeed json to wandb (#2593) [skip ci] 2025-04-30 03:32:44 -04:00
NanoCode012
5bb1f3da56 feat: add qwen3 moe block for ds3 (#2596) [skip ci] 2025-04-30 03:32:23 -04:00
Wing Lian
a21b9cc472 patch to convert LR from tensor to float when using DS (#2595) [skip ci] 2025-04-30 03:31:57 -04:00
Aleksandr Dremov
41a1ec0c95 Plugins create_lr_scheduler support (#2584)
* lr_scheduler support

* fix

* Update scheduler.py

* Update scheduler.py

* cfg handling

* black

* remove debug

* remove adding the axolotl cfg to the scheduler mixin

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-29 17:08:30 -04:00
Dan Saunders
ecac731922 auto-enable lora kernels where possible (#2589)
* auto-enable lora kernels where possible

* test

* revert change to example yaml

* naming

* remove print

* slight logic change
2025-04-29 16:18:49 -04:00
NanoCode012
742fef4200 fix(doc): key used to point to url in multimodal doc (#2575) [skip ci] 2025-04-29 15:10:59 -04:00
Wing Lian
a39caf8824 bump vllm==0.8.5 for qwen3 support (#2583) [skip ci] 2025-04-29 15:10:40 -04:00
Wing Lian
07e4f2e25b support for qwen3 with lora kernels (#2588)
* support for qwen3 with lora kernels

* fix patch

* typo
2025-04-29 15:02:49 -04:00
Dan Saunders
c7d07de6b4 Fix eval + add smoke test (#2586)
* fix evaluate CLI

* add smoke test

* fix naming

* lint
2025-04-29 12:58:54 -04:00
Wing Lian
6565ae85d8 set config on the PluginManager for callback access (#2587) 2025-04-29 12:05:44 -04:00
Wing Lian
80b4edb4a7 Post release fixes (#2581)
* fix missing kwarg on child

* make the runpod test shorter

* update docs

* rename runpod test json file

* typing fixes and ordering of doc
2025-04-29 10:01:38 -04:00
Wing Lian
fedbcc0254 remove torch 2.4.1 CI as part of support deprecation (#2582) 2025-04-29 08:28:32 -04:00
Wing Lian
8175896ada add dev tag for v0.10.0.dev0 (#2580) 2025-04-28 20:30:14 -04:00
75 changed files with 2618 additions and 453 deletions

View File

@@ -22,12 +22,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""

View File

@@ -15,11 +15,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -35,7 +30,7 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -67,6 +62,7 @@ jobs:
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -82,11 +78,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -9,6 +9,7 @@ on:
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -32,13 +33,6 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras: # no vllm support for 2.4.1
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -12,11 +12,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -70,11 +65,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -4,6 +4,12 @@ on:
pull_request:
types: [opened, synchronize, reopened]
# Run the workflow only when one of these files changes
paths:
- '**/*.md' # any Markdown file
- '**/*.qmd' # any Quarto file
- '_quarto.yaml'
permissions:
checks: write
contents: write

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -106,13 +106,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -27,6 +27,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
TRANSFORMERS_IS_CI: "yes"
jobs:
pre-commit:
name: pre-commit
@@ -49,7 +52,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -135,7 +138,7 @@ jobs:
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -258,6 +261,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

86
.runpod/test-input.json Normal file
View File

@@ -0,0 +1,86 @@
{
"input": {
"name": "quick_smoke_test_sft",
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
},
"timeout": 100000
},
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -1,65 +1,70 @@
{
"input": {
"name": "quick_smoke_test_sft",
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
}
},
"timeout": 100000
},
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,

View File

@@ -184,6 +184,10 @@ datasets:
# adding a system turn with empty content.
drop_system_message:
# Optional[bool]. Whether to split the assistant turn based on a reasoning trace inside delimited tags
# defaults to False
split_thinking:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`

View File

@@ -49,7 +49,8 @@ sections = [
("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum")
("Spectrum", "spectrum"),
("LLMCompressor", "llm_compressor")
]
for section_name, folder_name in sections:

View File

@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},

View File

@@ -0,0 +1,77 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -18,7 +18,7 @@ accelerate==1.6.0
datasets==3.5.0
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.0.0
hf_xet==1.1.0
hqq==0.2.5
optimum==1.16.2

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.4"]
extras_require_map["vllm"] = ["vllm==0.8.5"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.4"]
extras_require_map["vllm"] = ["vllm==0.8.5"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -149,6 +149,9 @@ extras_require = {
"vllm": [
"vllm==0.7.2",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.9.0"
__version__ = "0.10.0.dev0"

View File

@@ -2,4 +2,7 @@
import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -16,8 +16,15 @@ AXOLOTL_LOGO = """
@@@@ @@@@@@@@@@@@@@@@
"""
HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process():
HAS_PRINTED_LOGO = True
print(AXOLOTL_LOGO)

View File

@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ import logging
import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from urllib.parse import urlparse
@@ -152,7 +153,15 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
@@ -164,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
if isinstance(config, (str, Path)):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
@@ -184,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -213,5 +231,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
return cfg

View File

@@ -1,6 +1,7 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -14,6 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -29,10 +31,14 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -28,9 +28,8 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils import patch_optimized_env
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -56,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -101,7 +102,7 @@ def train(
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
patch_optimized_env()
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -327,6 +328,8 @@ def fetch(directory: str, dest: Optional[str]) -> None:
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
from axolotl.cli.vllm_serve import do_vllm_serve
do_vllm_serve(config, cli_args)

View File

@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
patch_optimized_env()
print_axolotl_text_art()
check_accelerate_default_config()

View File

@@ -20,11 +20,9 @@ from transformers import (
ProcessorMixin,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -11,5 +11,6 @@ MOE_ARCH_BLOCK = {
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
}

View File

@@ -47,7 +47,8 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
@@ -56,6 +57,7 @@ def load_datasets(
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
Returns:
Dataclass with fields for training and evaluation datasets and the computed
@@ -64,7 +66,8 @@ def load_datasets(
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
hasattr(cli_args, "iterable")
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
@@ -76,20 +79,25 @@ def load_datasets(
preprocess_iterable=preprocess_iterable,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
num_examples=num_examples,
text_only=text_only,
)
LOG.info("printing prompters...")

View File

@@ -21,6 +21,7 @@ import importlib.util
import inspect
import logging
import math
import os
import sys
from abc import abstractmethod
from pathlib import Path
@@ -60,6 +61,7 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
@@ -71,6 +73,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -114,6 +117,8 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@@ -290,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -485,7 +494,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
self.cfg.max_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = (

View File

@@ -114,6 +114,8 @@ class AxolotlTrainer(
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res

View File

@@ -63,6 +63,7 @@ class GRPOStrategy:
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -70,6 +71,13 @@ class GRPOStrategy:
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.loss_type is not None:
grpo_args_kwargs["loss_type"] = trl.loss_type
if trl.mask_truncated_completions is not None:
grpo_args_kwargs["mask_truncated_completions"] = (
trl.mask_truncated_completions
)
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
@@ -85,6 +93,11 @@ class GRPOStrategy:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
if trl.epsilon_high is not None:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
return grpo_args_kwargs

View File

@@ -3,9 +3,10 @@
import logging
import torch
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -25,9 +26,9 @@ class SchedulerMixin(Trainer):
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
) -> LRScheduler:
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
@@ -47,7 +48,16 @@ class SchedulerMixin(Trainer):
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
trainer=self,
optimizer=optimizer,
num_training_steps=num_training_steps
)
if lr_scheduler is not None:
LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
self.lr_scheduler = lr_scheduler
elif self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
@@ -110,4 +120,4 @@ class SchedulerMixin(Trainer):
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
return self.lr_scheduler # type: ignore

View File

@@ -1,6 +1,7 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -19,9 +20,11 @@ class ReLoRATrainer(AxolotlTrainer):
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
):
) -> LRScheduler:
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
if self.args.relora_steps:
warmup_steps = (
@@ -30,7 +33,7 @@ class ReLoRATrainer(AxolotlTrainer):
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
self.lr_scheduler = ReLoRAScheduler( # type: ignore
optimizer,
lr_scheduler,
self.args.relora_steps,
@@ -38,6 +41,6 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
self.lr_scheduler = lr_scheduler # type: ignore
return self.lr_scheduler
return self.lr_scheduler # type: ignore

View File

@@ -11,20 +11,19 @@ from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger("axolotl.evaluate")
LOG = get_logger(__name__)
def evaluate_dataset(
@@ -75,37 +74,22 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
Returns:
Dictionary mapping metric names to their values.
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
# Load tokenizer, processor and model
LOG.debug("loading model for evaluation...")
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(cfg, tokenizer, processor=processor)
# Set up trainer
trainer = setup_trainer(
cfg,
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=(model, None, None), # No need for model_ref or peft_config
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,

View File

@@ -24,6 +24,7 @@ import logging
from typing import OrderedDict
import torch
from torch.optim.lr_scheduler import LRScheduler
class BasePlugin:
@@ -41,7 +42,7 @@ class BasePlugin:
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
"""
@@ -146,8 +147,8 @@ class BasePlugin:
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -155,9 +156,10 @@ class BasePlugin:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
object: The created learning rate scheduler.
object (LRScheduler): The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
@@ -270,6 +272,7 @@ class PluginManager:
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
_cfg = None
def __new__(cls):
"""
@@ -277,7 +280,9 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict()
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
collections.OrderedDict()
)
return cls._instance
@staticmethod
@@ -290,6 +295,14 @@ class PluginManager:
PluginManager()
return PluginManager._instance # type: ignore
@property
def cfg(self):
return self._cfg
@cfg.setter
def cfg(self, cfg):
self._cfg = cfg
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
@@ -409,29 +422,29 @@ class PluginManager:
return trainer_cls
return None
def create_optimizer(self, cfg, trainer):
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
optimizer = plugin.create_optimizer(self.cfg, trainer)
if optimizer is not None:
return optimizer
return None
def create_lr_scheduler(self, cfg, trainer, optimizer):
def create_lr_scheduler(
self, trainer, optimizer, num_training_steps
) -> LRScheduler | None:
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
@@ -439,7 +452,12 @@ class PluginManager:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
scheduler: LRScheduler | None = plugin.create_lr_scheduler(
self.cfg,
trainer=trainer,
optimizer=optimizer,
num_training_steps=num_training_steps,
)
if scheduler is not None:
return scheduler
return None

View File

@@ -25,7 +25,7 @@ import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import is_main_process
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -72,11 +72,11 @@ class CutCrossEntropyPlugin(BasePlugin):
if cfg.cut_cross_entropy:
self._check_requirements()
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
from .monkeypatch.patch import (
cce_patch,
)
with zero_only():
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)

View File

@@ -37,6 +37,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=None,
train_on_eot=None,
eot_tokens=None,
split_thinking: bool | None = False,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
@@ -54,6 +55,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=train_on_eos,
train_on_eot=train_on_eot,
eot_tokens=eot_tokens,
split_thinking=split_thinking,
)
@property

View File

@@ -23,8 +23,8 @@ import logging
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
@@ -85,7 +85,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
if is_main_process(use_environ=True):
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
@@ -151,6 +151,30 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
else:
logging.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."

View File

@@ -0,0 +1,160 @@
"""
Liger FLCE for Qwen3. Based on transformers v4.51.3.
"""
import sys
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
if rms_norm:
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
if glu_activation:
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
if layer_norm:
modeling_qwen3.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,191 @@
"""
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
"""
import sys
from copy import deepcopy
from typing import List, Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
if rms_norm:
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,108 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -0,0 +1,5 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -0,0 +1,40 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -0,0 +1,171 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -0,0 +1,40 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -0,0 +1,19 @@
"""
attention module for attention monkeypatches
"""
from transformers.integrations.flash_attention import flash_attention_forward
def patch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .xformers import xformers_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
def unpatch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()

View File

@@ -12,10 +12,8 @@ import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)

View File

@@ -0,0 +1,160 @@
"""
xformers attention implementation for packing
"""
from typing import Optional
import torch
import xformers
import xformers.ops.fmha
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
xformers_attention = xformers.ops.fmha.memory_efficient_attention
def xformers_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
dropout: float = 0.0, # pylint: disable=unused-argument
scaling: Optional[float] = None, # pylint: disable=unused-argument
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None, # pylint: disable=unused-argument
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
if cu_seq_lens_q is None or cu_seq_lens_k is None:
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
cu_seq_lens_q = cu_seq_lens_q.squeeze()
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
attn_bias = (
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths.tolist(),
)
)
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
elif attention_mask is not None:
query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
kv_seqlen=seq_lengths,
)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
else:
# Handle Group Query Attention (GQA) using view/expand approach from reference
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
value = value.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
)
# If we need a sliding window attention
if has_sliding_window:
query = query.view(
1,
batch_size * query_length,
num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
)
return attn_output, None

View File

@@ -23,22 +23,42 @@ from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
QKV_PATCHES = [
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
"\n"
),
),
(
"""
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
),
]
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert any(
qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
), "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(
qkv_orig,
qkv_patched,
)
break
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",

View File

View File

@@ -0,0 +1,134 @@
"""
chunked ce loss
"""
from typing import List, Optional
import torch
import torch.nn.functional as F
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
class CEWithChunkedOutputLoss(torch.nn.Module):
"""
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
"""
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
super().__init__()
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
def compute_cross_entropy(
self,
logits: torch.Tensor,
labels: torch.Tensor,
normalize: bool = True, # pylint: disable=unused-argument
) -> torch.Tensor:
"""
Upcast logits to fp32 and compute cross entropy loss.
"""
return F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
)
def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor:
"""
Args:
logits (List[torch.Tensor]): List of chunked logits of length
``self.num_output_chunks``, where each chunk has shape
``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
reduction (str): The reduction to apply to the output.
Returns:
torch.Tensor: Cross entropy loss of shape (1,).
"""
total_elements = (labels != self.ignore_index).sum()
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
labels = [
target_chunk.reshape(-1)
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
]
# reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
logits = [
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
]
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
if reduction == "sum":
return total_loss
return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor"
)
return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
def chunked_fix_cross_entropy(
source,
target,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
logit_chunks = [ # pylint: disable=unnecessary-comprehension
chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
]
loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
def for_causal_lm_chunked_loss(
logits,
labels,
vocab_size: int = None, # pylint: disable=unused-argument
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# skip the upcast to float since we handle that in the chunking loss
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
# Skip Flattening the tokens
# Enable model parallelism
shift_labels = shift_labels.to(logits.device)
loss = chunked_fix_cross_entropy(
logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
)
return loss
return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss
)

View File

@@ -18,6 +18,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"qwen2_moe",
"qwen3",
"qwen3_moe",
"falcon",
"phi",
"phi3",

View File

View File

@@ -0,0 +1,78 @@
"""
Patch prepare_model_for_kbit_training to not upcast everything
"""
import inspect
import logging
import peft
import axolotl
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_PREPARE_CODE = """
for param in model.parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32)
"""
PATCHED_PREPARE_CODE = """
for name, param in model.named_parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit" and "norm" in name:
param.data = param.data.to(torch.float32)
"""
def get_peft_prep_code() -> str:
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
return prepare
def check_peft_prep_code_is_patchable() -> bool:
prep_code = get_peft_prep_code()
prep_code, _ = detab_code(prep_code)
return ORIGINAL_PREPARE_CODE in prep_code
def patch_peft_prep_code():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
prep_code = get_peft_prep_code()
except OSError:
return
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
prep_code
)
prep_code, _ = detab_code(prep_code)
if ORIGINAL_PREPARE_CODE not in prep_code:
return
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
prep_code = prep_code.replace(
"def prepare_model_for_kbit_training(",
"def fixed_prepare_model_for_kbit_training(",
1,
)
items_to_import = []
for item in dir(peft.utils.other):
if item in prep_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
globals(),
)
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

View File

@@ -0,0 +1,42 @@
"""
monkeypatch for Trainer _get_learning_rate method
"""
import logging
import torch
LOG = logging.getLogger(__name__)
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
def _get_learning_rate(self):
if self.is_deepspeed_enabled:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
LOG.warning(
"tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
)
last_lr = 0
else:
raise
else:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
last_lr = self.optimizer.param_groups[0]["lr"]
else:
last_lr = self.lr_scheduler.get_last_lr()[0]
if torch.is_tensor(last_lr):
last_lr = last_lr.item()
return last_lr
def patch_trainer_get_lr():
from transformers.trainer import Trainer
Trainer._get_learning_rate = _get_learning_rate # pylint: disable=protected-access

View File

@@ -4,7 +4,7 @@ HF Chat Templates prompt strategy
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
@@ -29,12 +29,12 @@ class ChatTemplatePrompter(Prompter):
chat_template: str,
processor=None,
max_length=2048,
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
message_property_mappings: Dict[str, str] | None = None,
message_field_training: str | None = None,
message_field_training_detail: str | None = None,
field_messages: str = "messages",
field_system: str = "system",
roles: Optional[Dict[str, List[str]]] = None,
roles: Dict[str, List[str]] | None = None,
drop_system_message: bool = False,
):
# check if message_property_mappings is None or empty dict
@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
}
if roles:
@@ -65,7 +66,7 @@ class ChatTemplatePrompter(Prompter):
self.field_messages = field_messages
self.field_system = field_system
self.tokenizer = tokenizer
self.processor: Optional[ProcessorMixin] = processor
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
@@ -224,11 +225,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
tokenizer,
train_on_inputs: bool,
sequence_len: int,
roles_to_train: Optional[List[str]] = None,
train_on_eos: Optional[str] = None,
train_on_eot: Optional[str] = None,
eot_tokens: Optional[List[str]] = None,
split_thinking: Optional[bool] = False,
roles_to_train: list[str] | None = None,
train_on_eos: str | None = None,
train_on_eot: str | None = None,
eot_tokens: list[str] | None = None,
split_thinking: bool | None = False,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
@@ -661,16 +662,46 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# if the role is assistant that we want to use reasoning_content
if self.split_thinking and transformed_message["role"] == "assistant":
content = transformed_message["content"]
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
for pair in pairs:
if pair[0] in content and pair[1] in content:
start_idx = content.find(pair[0])
end_idx = content.find(pair[1])
thinking_content = content[start_idx + len(pair[0]) : end_idx]
thinking_pairs = [
("<think>", "</think>"),
("<reasoning>", "</reasoning>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
]
content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
for tpair in thinking_pairs:
# check if the thinking pair is in the content
if tpair[0] in content and tpair[1] in content:
# find the start and end index of the thinking pair
t_start_idx = content.find(tpair[0])
t_end_idx = content.find(tpair[1])
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message["reasoning_content"] = thinking_content.strip()
transformed_message["content"] = content[
end_idx + len(pair[1]) :
].lstrip()
# take remainder of the content
# strip whitespace from beginning of the remainder (thinking tokens)
remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
# check if the content pair is in the remainder
cpair_found = False
for cpair in content_pairs:
if cpair[0] in remainder and cpair[1] in remainder:
# find the start and end index of the content pair
c_start_idx = remainder.find(cpair[0])
c_end_idx = remainder.find(cpair[1])
# get the content content
content_content = remainder[
c_start_idx + len(cpair[0]) : c_end_idx
]
transformed_message["content"] = content_content.strip()
cpair_found = True
break
# else, the content is the remainder
if not cpair_found:
transformed_message["content"] = remainder
break
# Determine which keys in the original message were not mapped
@@ -714,7 +745,7 @@ class StrategyLoader:
self,
tokenizer,
cfg,
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
ds_cfg: Union[Dict[str, Any], DatasetConfig] | None = None,
processor=None,
):
if ds_cfg is None:

View File

@@ -21,6 +21,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
@@ -30,7 +31,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +42,6 @@ try:
except ImportError:
BetterTransformer = None
configure_logging()
LOG = get_logger(__name__)
@@ -296,8 +295,23 @@ def save_trained_model(
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""
@@ -503,6 +517,8 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
print_axolotl_text_art()
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,

View File

@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)
def patch_optimized_env():
"""
Patch environment variables to improve VRAM usage and increase download speed
"""
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import gc
import json
import logging
import os
import traceback
@@ -808,11 +809,44 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
if args.deepspeed:
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
with NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
prefix="deepspeed_config_",
) as temp_file:
skip_upload = False
if isinstance(args.deepspeed, dict):
json.dump(args.deepspeed, temp_file, indent=4)
elif isinstance(args.deepspeed, str) and os.path.exists(
args.deepspeed
):
copyfile(args.deepspeed, temp_file.name)
else:
skip_upload = True
if not skip_upload:
artifact = wandb.Artifact(
f"deepspeed-config-{wandb.run.id}",
type="deepspeed-config",
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The DeepSpeed config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving DeepSpeed config to WandB: {err}")
return control
@@ -834,3 +868,29 @@ class GCCallback(TrainerCallback):
):
torch.cuda.empty_cache()
gc.collect()
def colab_inference_post_train_callback(trainer: Trainer):
class ColabCallback(TrainerCallback):
"""Callback to prep model for inference on Google Colab"""
def __init__(self, cfg):
self.gpu_name = torch.cuda.get_device_name(0)
self.cfg = cfg
def on_train_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
"""
handle T4 gpu, we need to convert attention to eager for inference
"""
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
trainer.model.eval()
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
trainer.model.gradient_checkpointing_disable()
trainer.model.config.use_cache = True
trainer.model.eval()
return ColabCallback

View File

@@ -59,7 +59,7 @@ def choose_device(cfg):
def resolve_dtype(cfg):
if (
cfg.bf16 == "auto" and not cfg.use_ray
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
): # if we use ray we want to defer this check to the worker node
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
@@ -67,9 +67,12 @@ def resolve_dtype(cfg):
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.fp16 is None:
if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True
if cfg.fp16 and cfg.bf16 == "auto":
cfg.bf16 = False
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False

View File

@@ -69,17 +69,27 @@ def barrier():
dist.barrier()
def is_main_process():
def is_main_process(use_environ=False):
"""
Check if the current process is the main process. If not in distributed mode,
always return `True`.
Args:
- use_environ (bool, optional): Use environment variable to determine main process.
Returns:
- bool: `True` if the current process is the main process, `False` otherwise.
"""
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
if not is_distributed():
return True
return dist.get_rank() == 0
def is_local_main_process():
def is_local_main_process(use_environ=False):
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
return PartialState().is_local_main_process
@@ -99,17 +109,6 @@ def cleanup_distributed():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager
def zero_first(is_main):
"""

View File

@@ -68,7 +68,7 @@ from axolotl.utils.distributed import (
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
@@ -141,6 +141,22 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
@@ -437,7 +453,7 @@ def load_tokenizer(cfg):
{"additional_special_tokens": additional_special_tokens}
)
with zero_only():
if is_main_process(use_environ=True):
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -540,11 +556,30 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.chunked_cross_entropy:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn()
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.adapter:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
@@ -1164,7 +1199,7 @@ class ModelLoader:
],
)
def prepare_model(self, qlora_fsdp) -> None:
def prepare_model(self, qlora_fsdp: bool) -> None:
skip_prepare_model_for_kbit_training = False
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -1293,7 +1328,7 @@ class ModelLoader:
# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
if not self.cfg.fsdp:
if self.cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
self.convert_embedding_modules_dtype(
embedding_modules,

View File

@@ -1,10 +1,13 @@
# pylint: skip-file
"""
Multipack Batch Sampler
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import logging
import math
from typing import Any, Iterable, List, Union
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Iterable, List, Union
import numba
import numpy as np
@@ -13,26 +16,39 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
"""
First-fit-decreasing bin packing algorithm check
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
not_found = True
for idx in range(n):
for idx in range(num_bins):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
# If no bin could fit this sequence, packing failed
if not_found:
return False
@@ -40,240 +56,380 @@ def ffd_check(a: np.ndarray, c: int, n: int):
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
def pack_group(
sequence_lengths: np.ndarray,
group_offset: int,
bin_capacity: int,
max_bins: int,
bin_size: int,
safe_mode: bool = True,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
indices = np.argsort(a)[::-1]
a = a[indices]
Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
bin_size: Maximum number of sequences per bin
safe_mode: If True, use a more conservative packing approach
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True
for bin_idx, _ in enumerate(bins_remaining_space):
if (
bins_remaining_space[bin_idx] >= size
and len(bins_assigned_sequences[bin_idx]) < bin_size
):
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
# Create a new bin if needed and if we haven't reached the limit
if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([global_idx])
return bins_result
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
# Define a standalone function for multiprocessing
def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
return pack_group(
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
)
def pack_parallel(
sequence_lengths: np.ndarray,
bin_capacity: int,
group_size: int,
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
"""
Pack sequences into bins using parallel processing
s = 0
start_index = 0
result = []
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin as total number of tokens
group_size: Number of sequences to process in each group
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# Create tasks for parallel processing
tasks = []
for i in range(0, num_items, group_size):
group_lengths = sequence_lengths[i : i + group_size]
max_bins = len(group_lengths) # Allow as many bins as items in the group
tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
return all_bins
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
sequence_lengths: The lengths of all examples
rank: The current rank (for distributed training)
bin_capacity: The capacity of each bin (maximum sequence length)
num_ranks: Number of ranks (processes/GPUs)
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
rank_batches: List of batches for the current rank
total_tokens_used: Number of actual example tokens
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
rank_batches = []
total_tokens_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
current_bin = []
remaining_capacity = bin_capacity
for idx, size in enumerate(lengths):
# Process each sequence in order
for idx, size in enumerate(sequence_lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_used += size
total_tokens_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
total_used += size
remaining_capacity = bin_capacity - size
total_tokens_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
# Assign bins to ranks - each rank gets every num_ranks-th bin
for bin_idx in range(rank, len(all_bins), num_ranks):
rank_batches.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
return rank_batches, total_tokens_used, len(all_bins) * bin_capacity
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
"""
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.lengths = np.array(lengths, dtype=np.int32)
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
self.group_size = group_size
self.bin_size = bin_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
# Efficiency statistics tracking
self.total_tokens_used = 0
self.total_token_slots = 0
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
LOG.warning(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False):
indices = [idx for idx in self.sampler]
"""
Generate packed batches for training
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
Args:
set_stats: Whether to update efficiency statistics
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
Returns:
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
if self._batches is not None:
return self._batches
batches = [
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
# Get indices from the sampler
indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
# Get lengths of the selected sequences
lengths = self.lengths[indices]
# Pack sequences into bins using either sequential or parallel packing
if self.sequential:
bins, total_used, total_slots = allocate_sequentially(
lengths,
rank=0,
bin_capacity=self.batch_max_len,
num_ranks=1,
)
else:
# Use parallel packing
all_bins = pack_parallel(
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
)
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins)
batches = [
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested
if set_stats:
self.total_tokens_used += total_used
self.total_token_slots += total_slots
self._batches = batches
return batches
def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks
# Truncate batches to ensure all ranks have the same number of batches
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
def gather_efficiency(self):
"""
Gather and synchronize packing efficiency estimates across all distributed ranks
Returns a conservative efficiency estimate based on the measurements
"""
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return math.floor(0.997 * max(estimates))
# Use 99.7% of max observed efficiency as a safe estimate
max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
# Quantize to 0.5% intervals for stability
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
"""
Gather and synchronize batch counts across all distributed ranks
Returns the minimum number of batches available on any rank
"""
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates))
# Find minimum batch count across ranks to ensure balance
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks

View File

@@ -242,6 +242,9 @@ class AxolotlInputConfig(
unsloth_rms_norm: bool | None = None
unsloth_rope: bool | None = None
chunked_cross_entropy: bool | None = None
chunked_cross_entropy_num_chunks: int | None = None
lora_mlp_kernel: bool | None = None
lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None
@@ -435,16 +438,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before")
@classmethod
# pylint: disable=duplicate-code
@@ -512,10 +505,17 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@@ -1150,6 +1150,18 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_grpo_peft_liger(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("adapter")
):
raise ValueError("PEFT + GRPO + Liger is not yet supported")
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
@@ -1315,6 +1327,57 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):
# Only proceed if using LoRA or QLoRA adapter
if data.get("rl"):
# RL trainers not tested so don't enable kernels by default
return data
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
or data.get("adapter") == "lora"
and data.get("load_in_8bit")
):
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
is_fsdp = data.get("fsdp") is not None
is_fsdp2 = (
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if (
not is_multi_gpu
or (is_multi_gpu and not is_fsdp)
or (is_multi_gpu and is_fsdp2)
):
# Auto-enable kernels if not explicitly set by user
if data.get("lora_mlp_kernel") is None:
data["lora_mlp_kernel"] = True
if data.get("lora_qkv_kernel") is None:
data["lora_qkv_kernel"] = True
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
+ "See https://docs.axolotl.ai/docs/lora_optims.html for more info."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):

View File

@@ -67,6 +67,12 @@ class TRLConfig(BaseModel):
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
},
)
sync_ref_model: bool | None = Field(
default=False,
json_schema_extra={
@@ -133,3 +139,25 @@ class TRLConfig(BaseModel):
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)
epsilon_high: float | None = Field(
default=None,
json_schema_extra={
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
},
)
use_liger_loss: bool | None = Field(
default=None,
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
)
loss_type: str | None = Field(
default=None,
json_schema_extra={
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
},
)
mask_truncated_completions: bool = Field(
default=False,
json_schema_extra={
"description": "When enabled, truncated completions are excluded from the loss calculation."
},
)

View File

@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg):

View File

@@ -72,7 +72,7 @@ class LogHooksPlugin(BasePlugin):
f.write("get_trainer_cls\n")
def create_lr_scheduler(
self, cfg, trainer, optimizer
self, cfg, trainer, optimizer, num_training_steps
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -172,7 +172,7 @@ class TestPluginHooks:
assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents
# assert "create_lr_scheduler" in file_contents # not implemented yet
assert "create_lr_scheduler" in file_contents
assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents

View File

@@ -0,0 +1,111 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_model_output_exists,
require_llmcompressor,
require_torch_2_4_1,
)
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_llmcompressor
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
from llmcompressor import active_session
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
try:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
finally:
active_session().reset()
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
if save_compressed:
assert (Path(temp_dir) / "recipe.yaml").exists()
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)

View File

@@ -2,14 +2,19 @@
# pylint: disable=redefined-outer-name
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
from axolotl.cli.config import load_cfg
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
@@ -66,29 +71,36 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
@pytest.mark.parametrize(
"model_name,attention_cls",
[
("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
],
)
def test_attention_patching_integration(model_name, attention_cls):
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
cfg = {"base_model": model_name}
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
original_forward = getattr(attention_cls, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
patched_forward = attention_cls.forward
# Check the forward method was replaced
assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored
assert hasattr(LlamaAttention, "_original_forward")
assert hasattr(attention_cls, "_original_forward")
# Clean up
setattr(LlamaAttention, "forward", original_forward)
delattr(LlamaAttention, "_original_forward")
setattr(attention_cls, "forward", original_forward)
delattr(attention_cls, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model):
@@ -413,3 +425,42 @@ def test_kernel_training_integration():
# Verify correct activation function
layer = model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
def test_kernel_training_integration_auto_enable(temp_dir):
"""Test model loading with auto-enabled kernel patches."""
# Create minimal config without explicitly setting kernel options
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Verify kernel options were auto-enabled in the config
assert cfg.lora_mlp_kernel is True
assert cfg.lora_qkv_kernel is True
assert cfg.lora_o_kernel is True

View File

@@ -0,0 +1,65 @@
"""E2E smoke test for evaluate CLI command"""
import os
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
os.environ["WANDB_DISABLED"] = "true"
class TestE2eEvaluate:
"""Test cases for evaluate CLI"""
def test_evaluate(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.evaluate",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -105,7 +105,25 @@ def require_vllm(test_case):
return False
return unittest.skipUnless(
is_vllm_installed(), "test requires a vllm to be installed"
is_vllm_installed(), "test requires vllm to be installed"
)(test_case)
def require_llmcompressor(test_case):
"""
Decorator marking a test that requires a llmcompressor to be installed
"""
def is_llmcompressor_installed():
try:
import llmcompressor # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
)(test_case)

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"pad_to_sequence_len": False,
"flash_attention": True,
}
)
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
for record in self._caplog.records
)
def test_packing_autoset(self, minimal_cfg):
cfg = (
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
in record.message
for record in self._caplog.records
)
assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.

View File

@@ -34,7 +34,31 @@ def messages_w_reasoning_fixture():
"content": "<think>lorem</think>\nwelcome",
},
]
}
},
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>",
},
]
},
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "<reasoning>lorem</reasoning>\nwelcome",
},
]
},
]
)
@@ -83,36 +107,37 @@ class TestSplitThinking:
}
),
)
transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
assert transformed_prompt[0]["role"] == "user"
assert transformed_prompt[1]["role"] == "assistant"
assert transformed_prompt[1]["reasoning_content"] == "lorem"
assert transformed_prompt[1]["content"] == "welcome"
for conversation in messages_w_reasoning:
transformed_prompt = strategy.get_conversation_thread(conversation)
assert transformed_prompt[0]["role"] == "user"
assert transformed_prompt[1]["role"] == "assistant"
assert transformed_prompt[1]["reasoning_content"] == "lorem"
assert transformed_prompt[1]["content"] == "welcome"
res = strategy.tokenize_prompt(messages_w_reasoning[0])
input_ids = res["input_ids"]
# fmt: off
expected_input_ids = [
151644, # im_start
872, # user
198, # \n
14990, # hello
151645, # im_end
198, # \n
151644, # im_start
77091, # assistant
198, # \n
151667, # think
198, # \n
385, 1826, # lorem
198, # \n
151668, # /think
271, # \n
34084, # welcome
151645, # im_end
198, # \n
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
res = strategy.tokenize_prompt(conversation)
input_ids = res["input_ids"]
# fmt: off
expected_input_ids = [
151644, # im_start
872, # user
198, # \n
14990, # hello
151645, # im_end
198, # \n
151644, # im_start
77091, # assistant
198, # \n
151667, # think
198, # \n
385, 1826, # lorem
198, # \n
151668, # /think
271, # \n
34084, # welcome
151645, # im_end
198, # \n
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

View File

@@ -0,0 +1,40 @@
"""
test suite for chunked cross entropy
"""
import pytest
import torch
from torch import nn
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
@pytest.fixture
def chunked_fixtures():
model_dim = 512
vocab_size = 1024 * 256
seq_len = 2048
batch_size = 1
lm_head = nn.Linear(model_dim, vocab_size)
hidden_state = torch.randn(batch_size, seq_len, model_dim)
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
return lm_head, hidden_state, labels, vocab_size
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
lm_loss = get_causal_lm_loss()
logits = lm_head(hidden_state)
chunked_lm_loss = lm_loss(logits, labels)
logits_flattened = logits.view(-1, vocab_size)
labels_flattened = labels.view(-1)
loss = nn.functional.cross_entropy(
logits_flattened.float(), labels_flattened, reduction="mean"
)
assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)