Compare commits

..

47 Commits

Author SHA1 Message Date
Wing Lian
8cda9e93c1 set version for v0.9.1
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-05-07 16:10:51 -04:00
Wing Lian
17d715c2b3 swap tinymodels that have safetensors for some ci tests (#2641) 2025-05-07 16:10:18 -04:00
xzuyn
f943306263 Add CAME Optimizer (#2385) 2025-05-07 16:10:17 -04:00
NanoCode012
3c8b9b33d6 fix(doc): clarify instruction to delinearize llama4 similar to cli doc (#2644) [skip ci] 2025-05-07 16:10:17 -04:00
NanoCode012
8b0c2a71ad Fix: improve error message on failed dataset load (#2637) [skip ci]
* fix(log): clarify error on dataset loading failed

* fix: add path for easy tracking of broken config

* fix: improve error message based on pr feedback
2025-05-07 16:10:17 -04:00
Wing Lian
493910559a Configurable embeddings upcast (#2621)
* fsdp embeddings should be float32 per comment

* patch peft to not upcast everything

* add tabs back to code check

* fix import

* add configurable option and fix check

* add check for dtypes

* move embeddings test to patch dir

* fix test

* fix comment and logic
2025-05-07 16:10:16 -04:00
Eric Meier
c54534dbfa Fix cut_cross_entropy plugin install (#2642) [skip ci] 2025-05-07 16:10:16 -04:00
Wing Lian
cae5cebb59 xformers attention with packing (#2619)
* xformers attention with packing

* wire up the patch

* fix xformers + packing validation

* fix warning

* reorder the packing check

* fix fp16 / bf16 reset when using fp16 with bf16 auto

* fix seq lens calc to drop hanging sequences

* handle xformers patch for inference too

* fix batch size setter

* fix xformers inference

* add colab callback to fix inference post train

* PR feedback
2025-05-07 16:10:16 -04:00
Wing Lian
fcbd7477d0 Multipack parallel bin packing (#2631)
* improve readability of multipack sampler

* parallel bin packing
fix error with lambda and pickling

make sure things are in float instead of np.float

* annotations and comments update

* support for configurable group and bin size for sample packing

* fix missing map back to original indices
2025-05-07 16:10:15 -04:00
Wing Lian
038db85a40 allow plugins to return their own dataset (#2617) [skip ci]
* allow plugins to return their own dataset

* add post_trainer_create and wire up

* add hook check

* address PR feedback:

* remove annotation causing circular import
2025-05-07 16:10:15 -04:00
NanoCode012
680dcc5a4d feat(doc): add split_thinking docs (#2613) [skip ci]
* feat(doc): add split_thinking docs

* fix: link config.qmd to conversation.qmd for split_thinking example

* update thinking => reasoning_content in messages format

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-07 16:10:15 -04:00
Wing Lian
fed5ca8254 bump liger dep to 0.5.9 (#2640) [skip ci]
* bump liger dep to 0.5.9

* also upgrade vllm to post1, and datasets to 3.5.1
2025-05-07 16:10:15 -04:00
mhenrichsen
7a2d017c88 Update lr_scheduler options in config.qmd to include additional scheduling strategies for improved training flexibility. (#2636) [skip ci] 2025-05-07 16:10:15 -04:00
Wing Lian
8c0303aa5e Print axolotl art if train is called outside of cli: (#2627) [skip ci] 2025-05-07 16:10:14 -04:00
Wing Lian
5d61169f7c fix dpo eval override to call grandparent instead of the broken super (#2628) [skip ci] 2025-05-07 16:10:14 -04:00
Wing Lian
e1586f7919 make sure gc_steps is used for all trainers (#2638) 2025-05-07 16:10:14 -04:00
Wing Lian
e4bf3ffb17 repop cache (#2639)
* repop cache

* pre-cache as a step

* fix the name

* add reason for pytest skipif

* restore pytorch matrix

* remove max-parallel now that we've optimized this a bit
2025-05-07 16:10:14 -04:00
mhenrichsen
30150fe1e1 Adds example for training a TTS model on top of a LLM. (#2614)
* Adds example for training a TTS model on top of a LLM.

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update README.md to clarify GPU requirements for finetuning Orpheus TTS model

* Update finetune.yml to use the new base model canopylabs/orpheus-3b-0.1-pretrained

* Update finetune.yml and README.md for consistency and clarity

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-07 16:10:14 -04:00
Emmanuel Ferdman
7f7d7ade2e Fix logging deprecation warnings (#2623)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-07 16:10:14 -04:00
Wing Lian
776cf70fe4 include multipack support for qwen3 family (#2622) 2025-05-07 16:10:14 -04:00
Wing Lian
8730951aba setup hf transfer too and fix auto bf16 when fp16 enabled (#2620) [skip ci] 2025-05-07 16:10:13 -04:00
Wing Lian
e72c11ad55 qwen3 and qwen3_moe support for liger kernels (#2612)
* qwen3 and qwen3_moe support for liger kernels

* fix moe module path

* fix: qwen3 liger input args and mlp

* fix: qwen3 input args and output class

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-07 16:10:13 -04:00
aitechguy
1a7978b960 remove keys to incoporate changes for the trl update (#2616) 2025-05-07 16:10:13 -04:00
Wing Lian
60b0d14f1d automatically set pad_to_sequence_len when use packing (#2607)
* automatically set pad_to_sequence_len when use packing

* update tests
2025-05-07 16:10:13 -04:00
NanoCode012
a7a40378f5 fix: run preview-docs only when md/qmd changes (#2606)
* fix: run preview-docs only when md/qmd changes

* feat: add quarto yaml based on PR feedback
2025-05-07 16:10:13 -04:00
Wing Lian
b50d35bec9 Logging config for colab (#2611)
* only configure logging on cli to play nicely with colab

* allow reloading the config on the fly from a dict

* make sure to use dict for yaml

* reuse existing function for load

* make cli args optional

* mps fix and respect max_steps
2025-05-07 16:10:13 -04:00
Wing Lian
bc6dfa6899 add missing __init__ for lr monkeypatch fix (#2609) 2025-05-07 16:10:13 -04:00
Dhruv Mullick
9d6e8af622 Add num_completions_to_print for trl and grpo (#2604) 2025-05-07 16:10:12 -04:00
Wing Lian
17b441248c use latest hf-xet and don't install vllm for torch 2.7.0 (#2603)
* use latest hf-xet and don't install vllm for torch 2.7.0

* fix runpod hub tests
2025-05-07 16:10:12 -04:00
Wing Lian
d49a4268b8 additional args for grpo config/trainer (#2598) 2025-05-07 16:10:12 -04:00
Wing Lian
1d6e931115 replace zero_only with simpler if statement (#2592) 2025-05-07 16:10:12 -04:00
Wing Lian
ff106ace44 ensure we pass axolotl extras to the Dockerfile so vllm is included in shipped images (#2599) 2025-05-07 16:10:12 -04:00
Wing Lian
24907533d1 don't automatically enable lora kernels for RL training (#2600) 2025-05-07 16:10:12 -04:00
Wing Lian
0e9d816d2e only import vllm serve cli if its being called (#2597) [skip ci] 2025-05-07 16:10:12 -04:00
Wing Lian
72f142186a Handle other reasoning trace dataset formats (#2591)
* Handle other reasoning trace dataset formats

* rename var to improve readability

* chore: refactor with comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-07 16:10:11 -04:00
Wing Lian
87726322bf upload the deepspeed json to wandb (#2593) [skip ci] 2025-05-07 16:10:11 -04:00
NanoCode012
ae8ae7534c feat: add qwen3 moe block for ds3 (#2596) [skip ci] 2025-05-07 16:10:11 -04:00
Wing Lian
ee00142cb5 patch to convert LR from tensor to float when using DS (#2595) [skip ci] 2025-05-07 16:10:11 -04:00
Aleksandr Dremov
097e7e3b5b Plugins create_lr_scheduler support (#2584)
* lr_scheduler support

* fix

* Update scheduler.py

* Update scheduler.py

* cfg handling

* black

* remove debug

* remove adding the axolotl cfg to the scheduler mixin

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-07 16:10:11 -04:00
Dan Saunders
c714958181 auto-enable lora kernels where possible (#2589)
* auto-enable lora kernels where possible

* test

* revert change to example yaml

* naming

* remove print

* slight logic change
2025-05-07 16:10:11 -04:00
NanoCode012
4402c293dc fix(doc): key used to point to url in multimodal doc (#2575) [skip ci] 2025-05-07 16:10:10 -04:00
Wing Lian
0d71f787a3 bump vllm==0.8.5 for qwen3 support (#2583) [skip ci] 2025-05-07 16:10:10 -04:00
Wing Lian
c337ca0872 support for qwen3 with lora kernels (#2588)
* support for qwen3 with lora kernels

* fix patch

* typo
2025-05-07 16:10:10 -04:00
Dan Saunders
f04f7cf5ad Fix eval + add smoke test (#2586)
* fix evaluate CLI

* add smoke test

* fix naming

* lint
2025-05-07 16:10:10 -04:00
Wing Lian
c64a951bc9 set config on the PluginManager for callback access (#2587) 2025-05-07 16:10:10 -04:00
Wing Lian
fc88cc56cb Post release fixes (#2581)
* fix missing kwarg on child

* make the runpod test shorter

* update docs

* rename runpod test json file

* typing fixes and ordering of doc
2025-05-07 16:10:10 -04:00
Wing Lian
e85cbb8645 remove torch 2.4.1 CI as part of support deprecation (#2582) 2025-05-07 16:10:10 -04:00
36 changed files with 1358 additions and 2159 deletions

View File

@@ -329,12 +329,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -371,43 +365,3 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-cleanup:
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

View File

View File

@@ -18,7 +18,7 @@ pytest -v --durations=10 \
--cov-append
# Run patched tests excluding lora kernels with coverage append
pytest --full-trace -vvv --durations=10 \
pytest -v --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \

View File

@@ -1,19 +0,0 @@
"""Modal app to run axolotl GPU cleanup"""
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cleanup():
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
@app.local_entrypoint()
def main():
cleanup.remote()

View File

@@ -1,6 +0,0 @@
#!/bin/bash
set -e
# cleanup old cache files for datasets processing and intermediate mappings
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;

View File

@@ -1,6 +1,69 @@
"""Modal app to run axolotl GPU tests"""
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@app.function(

View File

@@ -1,66 +0,0 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -49,8 +49,7 @@ sections = [
("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum"),
("LLMCompressor", "llm_compressor")
("Spectrum", "spectrum")
]
for section_name, folder_name in sections:

View File

@@ -1,77 +0,0 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -150,9 +150,6 @@ extras_require = {
"vllm": [
"vllm==0.7.2",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.10.0.dev0"
__version__ = "0.9.1"

File diff suppressed because it is too large Load Diff

View File

@@ -1,21 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Init for axolotl.core.trainers.builders"""
# pylint: disable=unused-import
# flake8: noqa
from .causal import HFCausalTrainerBuilder
from .rl import HFRLTrainerBuilder

View File

@@ -1,331 +0,0 @@
"""Base class trainer / training args builder implementation"""
import abc
from typing import Any
from torch import Type
from transformers import TrainerCallback
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import GCCallback, SaveAxolotlConfigtoWandBCallback
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
PLUGIN_MANAGER = PluginManager.get_instance()
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
_train_dataset = None
_eval_dataset = None
_model_ref = None
_peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@property
def peft_config(self):
return self._peft_config
@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config
@abc.abstractmethod
def build(self, total_num_steps):
pass
def get_common_training_args_kwargs(
self, total_num_steps: int | None = None
) -> dict[str, Any]:
"""Get common training arguments kwargs used across different trainer types."""
training_args_kwargs = {}
# Common parameters
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"dataloader_drop_last",
"remove_unused_columns",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
# Add Hub integration arguments if needed
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
# BF16/FP16 settings
if hasattr(self.cfg, "bf16") and self.cfg.bf16:
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16
elif hasattr(self.cfg, "bfloat16") and self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
if hasattr(self.cfg, "fp16"):
training_args_kwargs["fp16"] = (
getattr(self.cfg, "fp16", False)
and not getattr(self.cfg, "bf16", False)
) or False
# Set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
# Handle safetensors
if self.cfg.save_safetensors is not None:
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
# Handle gradient checkpointing
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
# Common optimizer and LR scheduler settings
training_args_kwargs["optim"] = self.cfg.optimizer
if hasattr(self.cfg, "lr_scheduler") and self.cfg.lr_scheduler:
training_args_kwargs["lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = "cosine"
if hasattr(self.cfg, "lr_scheduler_kwargs") and self.cfg.lr_scheduler_kwargs:
training_args_kwargs["lr_scheduler_kwargs"] = self.cfg.lr_scheduler_kwargs
else:
training_args_kwargs["lr_scheduler_kwargs"] = {}
# LoRA+ specific settings
if hasattr(self.cfg, "loraplus_lr_ratio"):
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
if hasattr(self.cfg, "loraplus_lr_embedding"):
training_args_kwargs["loraplus_lr_embedding"] = (
self.cfg.loraplus_lr_embedding
)
# Reporting tools
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
if report_to:
training_args_kwargs["report_to"] = report_to
# Basic training settings
if hasattr(self.cfg, "sequence_len"):
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["save_only_model"] = getattr(
self.cfg, "save_only_model", False
)
training_args_kwargs["save_total_limit"] = getattr(
self.cfg, "save_total_limit", 5
)
# Compute warmup steps
if hasattr(self.cfg, "warmup_steps") and self.cfg.warmup_steps is not None:
training_args_kwargs["warmup_steps"] = self.cfg.warmup_steps
elif (
total_num_steps
and hasattr(self.cfg, "warmup_ratio")
and self.cfg.warmup_ratio is not None
):
training_args_kwargs["warmup_steps"] = max(
int(self.cfg.warmup_ratio * total_num_steps), 0
)
elif total_num_steps:
training_args_kwargs["warmup_steps"] = min(int(0.03 * total_num_steps), 100)
return training_args_kwargs
def create_training_args(
self,
args_cls: Type[TrainingArguments],
total_num_steps: int | None = None,
**additional_kwargs,
) -> TrainingArguments:
"""Create training arguments with common logic."""
# Get common trainings args and update with trainer-specific args
training_args_kwargs = self.get_common_training_args_kwargs(total_num_steps)
training_args_kwargs.update(additional_kwargs)
# Create training args with pre- and post-creation hooks
training_args_kwargs = self.hook_pre_create_training_args(training_args_kwargs)
training_args = args_cls(**training_args_kwargs)
training_args = self.hook_post_create_training_args(training_args)
# Unset run_name so wandb sets up experiment names properly
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = None
return training_args
def create_trainer(
self, trainer_cls, training_args, trainer_args=None, trainer_kwargs=None
):
"""Create trainer with common logic."""
if trainer_args is None:
trainer_args = []
if trainer_kwargs is None:
trainer_kwargs = {}
# Create trainer with pre- and post- creation hooks
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_args,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# Add post-creation callbacks
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
def get_callbacks(self) -> list[TrainerCallback]:
callbacks = []
callbacks.extend(
PLUGIN_MANAGER.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
"""Callbacks added after the trainer is created, usually because these need
access to the trainer.
"""
callbacks = []
if self.cfg.plugins:
callbacks.extend(
[
cb
for cb in PLUGIN_MANAGER.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer

View File

@@ -1,619 +0,0 @@
"""Causal trainer / training args builder implementation"""
import importlib
import inspect
import logging
import math
import os
import sys
from pathlib import Path
from typing import Type
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.builders.base import TrainerBuilderBase
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from axolotl.core.trainers.relora import ReLoRATrainer
from axolotl.core.trainers.trl import AxolotlPRMTrainer, AxolotlRewardTrainer
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators.batching import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mamba import MambaDataCollator
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""Build the HuggingFace training args / trainer for causal models and reward
modeling using TRL.
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback())
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
if self.cfg.plugins:
trainer_cls = PLUGIN_MANAGER.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
"""Build and return a causal trainer instance using the refactored base class."""
# Get trainer class
trainer_cls = self._get_trainer_cls()
# Prepare training arguments
training_args = self._prepare_training_args(total_num_steps)
# Prepare data collators
data_collator_kwargs = self._prepare_data_collator_kwargs()
# Prepare trainer kwargs
trainer_kwargs = self._prepare_trainer_kwargs(
trainer_cls=trainer_cls,
data_collator_kwargs=data_collator_kwargs,
training_args=training_args,
)
# Create the trainer
trainer = self.create_trainer(
trainer_cls=trainer_cls,
training_args=training_args,
trainer_kwargs={
"model": self.model,
"data_collator": self.build_collator(
training_args, **data_collator_kwargs
),
**trainer_kwargs,
},
)
# Handle DeepSpeed config for sample packing if needed
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
def _prepare_training_args(self, total_num_steps):
"""Prepare and return training arguments."""
# Base training arguments
training_args_kwargs = self._get_base_training_args()
# Add feature configurations
self._add_feature_configs(training_args_kwargs)
# Handle optimizer configuration
self._configure_optimizer(training_args_kwargs)
# Create training args using the base class method
training_args_cls = self._get_training_args_cls()
return self.create_training_args(
args_cls=training_args_cls,
total_num_steps=total_num_steps,
**training_args_kwargs,
)
def _get_base_training_args(self):
"""Return the base training arguments."""
return {
"max_steps": self.cfg.max_steps if self.cfg.max_steps else -1,
"max_seq_length": self.cfg.sequence_len,
"per_device_train_batch_size": self.cfg.micro_batch_size,
"gradient_accumulation_steps": self.cfg.gradient_accumulation_steps,
"eval_accumulation_steps": self.cfg.gradient_accumulation_steps,
"num_train_epochs": self.cfg.num_epochs,
"learning_rate": self.cfg.learning_rate,
"output_dir": self.cfg.output_dir,
"weight_decay": (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
),
"model_type": self.cfg.model_config_type,
"pretraining": bool(self.cfg.pretraining_dataset),
"sequence_parallel_degree": self.cfg.sequence_parallel_degree,
"ring_attn_func": self.cfg.ring_attn_func,
"embedding_lr": self.cfg.embedding_lr,
"embedding_lr_scale": self.cfg.embedding_lr_scale,
"loraplus_lr_ratio": self.cfg.loraplus_lr_ratio,
"loraplus_lr_embedding": self.cfg.loraplus_lr_embedding,
"lr_groups": self.cfg.lr_groups,
}
def _add_feature_configs(self, training_args_kwargs):
"""Add various feature configurations."""
# Sample packing configurations
self._add_sample_packing_configs(training_args_kwargs)
# Batch size configurations
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
if self.cfg.auto_find_batch_size is not None:
training_args_kwargs["auto_find_batch_size"] = self.cfg.auto_find_batch_size
# Advanced training techniques (ReLoRA & Lisa)
self._add_advanced_training_configs(training_args_kwargs)
# Model-specific configurations
self._add_model_specific_configs(training_args_kwargs)
def _add_sample_packing_configs(self, training_args_kwargs):
"""Add sample packing configurations if applicable."""
if hasattr(self.cfg, "sample_packing") and self.cfg.sample_packing:
training_args_kwargs.update(
{
"sample_packing": bool(self.cfg.sample_packing),
"multipack_real_batches": not self.cfg.flash_attention
or self.cfg.multipack_real_batches,
"eval_sample_packing": bool(self.cfg.eval_sample_packing),
}
)
if self.cfg.sample_packing_bin_size is not None:
training_args_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_args_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_args_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
def _add_advanced_training_configs(self, training_args_kwargs):
"""Add advanced training techniques configurations (ReLoRA & Lisa)."""
# ReLoRA configurations
if self.cfg.relora_steps:
training_args_kwargs.update(
{
"relora_steps": self.cfg.relora_steps,
"relora_warmup_steps": self.cfg.relora_warmup_steps,
}
)
if self.cfg.relora_anneal_steps:
training_args_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_args_kwargs["relora_prune_ratio"] = self.cfg.relora_prune_ratio
# Lisa configurations
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_args_kwargs.update(
{
"lisa_n_layers": self.cfg.lisa_n_layers,
"lisa_step_interval": self.cfg.lisa_step_interval,
"lisa_layers_attribute": self.cfg.lisa_layers_attribute,
}
)
def _add_model_specific_configs(self, training_args_kwargs):
"""Add model-specific configurations."""
# Chat template
if self.cfg.chat_template:
training_args_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
# NEFTune
if self.cfg.neftune_noise_alpha is not None:
training_args_kwargs["neftune_noise_alpha"] = self.cfg.neftune_noise_alpha
# Knowledge distillation configurations
if self.cfg.kd_ce_alpha is not None:
training_args_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_args_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_args_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_args_kwargs["kd_zscore_base_temp"] = self.cfg.kd_zscore_base_temp
if self.cfg.kd_top_k_before_softmax is not None:
training_args_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
# Image configurations
if self.cfg.image_size:
training_args_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_args_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
# Accelerator configuration
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
def _configure_optimizer(self, training_args_kwargs):
"""Configure optimizer settings."""
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Use custom optimizer implementation
self._configure_custom_optimizer(training_args_kwargs)
else:
# Use transformers' optimizer
training_args_kwargs["optim"] = self.cfg.optimizer
self._add_optimizer_args(training_args_kwargs)
# Handle optimizer targeting specific modules
if self.cfg.optim_target_modules:
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
# Special case for anyprecision optimizer
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
def _configure_custom_optimizer(self, training_args_kwargs):
"""Configure custom optimizer settings."""
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_args_kwargs.get("learning_rate"),
"weight_decay": training_args_kwargs.get("weight_decay"),
}
# Add Adam-specific kwargs if available
adam_kwargs = self._get_adam_kwargs(training_args_kwargs)
# Get optimizer class and update kwargs based on optimizer type
optimizer_cls = self._get_optimizer_class(
training_args_kwargs, optimizer_kwargs, adam_kwargs
)
# Add any additional optimizer args from config
self._update_optimizer_kwargs_from_config(optimizer_kwargs)
training_args_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
def _get_adam_kwargs(self, training_args_kwargs):
"""Get Adam-specific kwargs if available."""
adam_kwargs = {}
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
"adam_beta2"
):
adam_kwargs["betas"] = (
training_args_kwargs.get("adam_beta1"),
training_args_kwargs.get("adam_beta2"),
)
if training_args_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
return adam_kwargs
def _get_optimizer_class(self, training_args_kwargs, optimizer_kwargs, adam_kwargs):
"""Get optimizer class based on configuration."""
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import MuonOptimizerFactory # pylint: disable=no-name-in-module
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
beta3 = training_args_kwargs.get("adam_beta2", 0.9999)
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
else:
# Default case or unsupported optimizer
optimizer_cls = None
return optimizer_cls
def _update_optimizer_kwargs_from_config(self, optimizer_kwargs):
"""Update optimizer kwargs from config."""
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
def _add_optimizer_args(self, training_args_kwargs):
"""Add optimizer arguments if available."""
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_args_kwargs["optim_args"] = optim_args
def _get_training_args_cls(self):
"""Get the appropriate training arguments class."""
if self.cfg.reward_model:
return AxolotlRewardConfig
if self.cfg.process_reward_model:
return AxolotlPRMConfig
return AxolotlTrainingArguments
def _prepare_data_collator_kwargs(self):
"""Prepare data collator kwargs."""
data_collator_kwargs = {"padding": True} # True/"longest" is the default
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
return data_collator_kwargs
def _prepare_trainer_kwargs(self, trainer_cls, data_collator_kwargs, training_args):
"""Prepare trainer kwargs."""
trainer_kwargs = {}
# Handle special data collators for evaluation
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
# Add bench data collator if needed
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
# Add tokenizer or processing class
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
# Add dataset tags if available
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
return trainer_kwargs
def build_collator(
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
use_batch_sampler_collator = False
if is_eval is False and training_args.sample_packing:
use_batch_sampler_collator = True
if is_eval and training_args.eval_sample_packing:
use_batch_sampler_collator = True
collator: Type[
V2BatchSamplerDataCollatorForSeq2Seq
| BatchSamplerDataCollatorForSeq2Seq
| DataCollatorForSeq2Seq
| DataCollatorWithFlattening
| RewardDataCollatorWithPadding
]
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
return collator(
*collator_args,
**kwargs,
)

View File

@@ -1,367 +0,0 @@
"""RL trainer / training args builder implementation"""
import inspect
from pathlib import Path
from axolotl.core.trainers.builders.base import TrainerBuilderBase
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
from axolotl.utils.models import ensure_dtype
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build_training_arguments(self, total_num_steps):
training_args_kwargs = {}
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"dataloader_num_workers",
"dataloader_pin_memory",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors is not None:
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.eval_dataset:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["eval_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_args_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
# set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
max_steps = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=max_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
warmup_steps=self.cfg.warmup_steps,
logging_first_step=True,
logging_steps=1,
optim=self.cfg.optimizer,
save_total_limit=self.cfg.save_total_limit or 5,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):
"""Build and return an RL trainer instance"""
# Prepare RL-specific training args kwargs
training_args_kwargs = {
"per_device_train_batch_size": self.cfg.micro_batch_size,
"max_steps": self.cfg.max_steps or total_num_steps or -1,
"gradient_accumulation_steps": self.cfg.gradient_accumulation_steps,
"learning_rate": self.cfg.learning_rate,
"warmup_steps": self.cfg.warmup_steps,
"logging_first_step": True,
"logging_steps": 1,
"output_dir": self.cfg.output_dir,
"num_train_epochs": self.cfg.num_epochs,
}
# Handle dataset processes
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# Handle beta/alpha parameters for different RL algorithms
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
# Determine training args class and add RL-specific parameters
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else: # Default to DPO
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
# Remove any blocklisted arguments
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
# Create training args using the base class method
training_args = self.create_training_args(
args_cls=training_args_cls,
total_num_steps=total_num_steps,
**training_args_kwargs,
)
# Prepare trainer kwargs
trainer_kwargs = {}
if self.cfg.rl == "ipo" and self.cfg.dpo_label_smoothing:
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
# Determine trainer class and arguments
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_args = [self.model]
trainer_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_args = [self.model]
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_args = [self.model]
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
# Add tokenizer or processing class
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
# Add dataset tags if available
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# Create the trainer
trainer = self.create_trainer(
trainer_cls=trainer_cls,
training_args=training_args,
trainer_args=trainer_args,
trainer_kwargs=trainer_kwargs,
)
# Handle FSDP specific settings
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if (
self.cfg.rl in ["dpo", "ipo"]
and hasattr(trainer, "ref_model")
and trainer.ref_model
):
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
return trainer

View File

@@ -1,108 +0,0 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -1,5 +0,0 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -1,40 +0,0 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -1,171 +0,0 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -1,40 +0,0 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -26,7 +26,7 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainers.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
@@ -294,23 +294,8 @@ def save_trained_model(
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""

View File

@@ -46,11 +46,11 @@ from axolotl.utils.distributed import (
from axolotl.utils.schemas.config import AxolotlInputConfig
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
from axolotl.core.trainer_builder import AxolotlTrainingArguments
IGNORE_INDEX = -100
LOG = logging.getLogger(__name__)
LOG = logging.getLogger("axolotl.callbacks")
class EvalFirstStepCallback(

View File

@@ -141,22 +141,6 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config

View File

@@ -6,7 +6,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
from multiprocessing import cpu_count
from typing import Iterable, Union
import numba
@@ -78,11 +78,15 @@ def pack_group(
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sequence_lengths):
global_idx = seq_id + group_offset
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True
@@ -126,7 +130,6 @@ def pack_parallel(
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
mp_start_method: str | None = "spawn",
):
"""
Pack sequences into bins using parallel processing
@@ -138,9 +141,7 @@ def pack_parallel(
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
'spawn' is often safer with Numba/PyTorch.
Set to None to use system default.
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
@@ -157,33 +158,9 @@ def pack_parallel(
# Process groups in parallel
all_bins = []
mp_ctx = None
if mp_start_method:
try:
mp_ctx = get_context(mp_start_method)
except ValueError:
LOG.warning(
f"Failed to get multiprocessing context '{mp_start_method}'. "
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
)
mp_ctx = (
None # Fallback to default context if specified one is not available
)
if num_processes == 1:
LOG.debug("Using single process for pack_parallel, running sequentially.")
for task_args in tasks:
group_bins = _process_group(task_args)
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
else:
# Use ProcessPoolExecutor only if num_processes > 1
# Pass mp_context if available
with ProcessPoolExecutor(
max_workers=num_processes, mp_context=mp_ctx
) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
return all_bins

View File

@@ -16,7 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainers.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -633,7 +633,8 @@ def setup_trainer(
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance configured based on the provided parameters.
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if (
cfg.torch_compile

View File

@@ -1,8 +1,10 @@
"""Unit tests for axolotl.core.trainers.builders"""
"""
unit tests for axolotl.core.trainer_builder
"""
import pytest
from axolotl.core.trainers.builders import HFRLTrainerBuilder
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -51,7 +53,9 @@ def fixture_model(cfg, tokenizer):
class TestHFRLTrainerBuilder:
"""Test case class for RL trainer builder"""
"""
TestCase class for DPO trainer builder
"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)

View File

@@ -1,111 +0,0 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_model_output_exists,
require_llmcompressor,
require_torch_2_4_1,
)
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_llmcompressor
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
from llmcompressor import active_session
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
try:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
finally:
active_session().reset()
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
if save_compressed:
assert (Path(temp_dir) / "recipe.yaml").exists()
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)

View File

@@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
@@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)

View File

@@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
@@ -99,9 +99,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)

View File

@@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
@@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)

View File

@@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"bf16": "auto",
}
)
@@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"bf16": "auto",
}
)

View File

@@ -1,21 +1,21 @@
"""Test module to import various submodules that have historically broken due to
dependency issues.
"""
test module to import various submodules that have historically broken due to dependency issues
"""
import unittest
class TestImports(unittest.TestCase):
"""Test class to import various submodules that have historically broken due to
dependency issues.
"""
Test class to import various submodules that have historically broken due to dependency issues
"""
def test_import_causal_trainer(self):
from axolotl.core.trainers.builders import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
HFCausalTrainerBuilder,
)
def test_import_rl_trainer(self):
from axolotl.core.trainers.builders import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
HFRLTrainerBuilder,
)

View File

@@ -105,25 +105,7 @@ def require_vllm(test_case):
return False
return unittest.skipUnless(
is_vllm_installed(), "test requires vllm to be installed"
)(test_case)
def require_llmcompressor(test_case):
"""
Decorator marking a test that requires a llmcompressor to be installed
"""
def is_llmcompressor_installed():
try:
import llmcompressor # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
is_vllm_installed(), "test requires a vllm to be installed"
)(test_case)

View File

@@ -106,4 +106,3 @@ class TestBatchedSamplerPacking:
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs))