Compare commits

...

17 Commits

Author SHA1 Message Date
Rahul Tuli
3a8b637598 Tests, Style, Updates 2025-04-30 17:21:52 -04:00
Rahul Tuli
12cd09e6f5 Rebase and updates! 2025-04-30 17:21:52 -04:00
Rahul Tuli
fe82f62248 Add: llm_compressor integration documentation 2025-04-30 17:21:52 -04:00
Rahul Tuli
db31d7ad22 Move: LLMCompressorPlugin into it's own submodule 2025-04-30 17:21:52 -04:00
Rahul Tuli
eb7f2aa4b9 Update model config 2025-04-30 17:21:51 -04:00
Rahul Tuli
f80e36ddd2 Use: absolute import 2025-04-30 17:21:51 -04:00
Rahul Tuli
412d2ec6d0 Rename: sft.yaml to sparse-finetuning.yaml 2025-04-30 17:21:51 -04:00
Rahul Tuli
50fc5e6984 Add: llcompressor installable 2025-04-30 17:21:51 -04:00
Rahul Tuli
83a88b745f Address review comments from @markurtz 2025-04-30 17:21:51 -04:00
Rahul Tuli
8855bb115f Apply suggestions from @markurtz
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
2025-04-30 17:21:51 -04:00
Rahul Tuli
ef9543b371 Update llmcompressor version to latest 2025-04-30 17:21:51 -04:00
Rahul Tuli
25e701e885 Revert: TODO's 2025-04-30 17:21:50 -04:00
Rahul Tuli
891a21e599 Use: warning over warn 2025-04-30 17:21:50 -04:00
Rahul Tuli
8beb2f27ad pre commit hooks 2025-04-30 17:21:50 -04:00
Rahul Tuli
56ba66b60f Add:llmcompressor instalable 2025-04-30 17:21:50 -04:00
Rahul Tuli
13d4b865d6 Update: review comments! 2025-04-30 17:21:50 -04:00
Rahul Tuli
3da866b2b9 Add: SFTPlugin with llmcompressor 2025-04-30 17:21:50 -04:00
10 changed files with 539 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -149,6 +149,9 @@ extras_require = {
"vllm": [ "vllm": [
"vllm==0.7.2", "vllm==0.7.2",
], ],
"llmcompressor": [
"llmcompressor==0.5.1",
],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -0,0 +1,70 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Or downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
👉 [https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -0,0 +1,5 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -0,0 +1,40 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -0,0 +1,171 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -0,0 +1,40 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -288,6 +288,19 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError: except FileNotFoundError:
pass pass
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
@@ -296,6 +309,7 @@ def save_trained_model(
trainer.model.save_pretrained( trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
) )
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

View File

@@ -141,6 +141,22 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config") hasattr(model_config, "quantization_config")
and model_config.quantization_config and model_config.quantization_config
) )
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = ( quant_config_method_is_gptq = (
quant_config_exists quant_config_exists
and "quant_method" in model_config.quantization_config and "quant_method" in model_config.quantization_config

View File

@@ -0,0 +1,103 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
# recipe.yaml should exist
assert (Path(temp_dir) / "recipe.yaml").exists()
# sparsity config exists if save_compressed
if save_compressed:
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)