diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index e633215e4..c2c085fa0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -261,6 +261,18 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.6.0
+            num_gpus: 1
+            axolotl_extras: llmcompressor
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.1
+            num_gpus: 1
+            axolotl_extras:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
diff --git a/docs/custom_integrations.qmd b/docs/custom_integrations.qmd
index cb4aef9ca..023f09732 100644
--- a/docs/custom_integrations.qmd
+++ b/docs/custom_integrations.qmd
@@ -49,7 +49,8 @@ sections = [
     ("Knowledge Distillation (KD)", "kd"),
     ("Liger Kernels", "liger"),
     ("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
-    ("Spectrum", "spectrum")
+    ("Spectrum", "spectrum"),
+    ("LLMCompressor", "llm_compressor")
 ]
 
 for section_name, folder_name in sections:
diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
new file mode 100644
index 000000000..1bbb88028
--- /dev/null
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -0,0 +1,77 @@
+base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
+
+plugins:
+  - axolotl.integrations.llm_compressor.LLMCompressorPlugin
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+eval_sample_packing: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>
+
+llmcompressor:
+  recipe:
+    finetuning_stage:
+      finetuning_modifiers:
+        ConstantPruningModifier:
+          targets: [
+            're:.*q_proj.weight',
+            're:.*k_proj.weight',
+            're:.*v_proj.weight',
+            're:.*o_proj.weight',
+            're:.*gate_proj.weight',
+            're:.*up_proj.weight',
+            're:.*down_proj.weight',
+          ]
+          start: 0
+  save_compressed: true
diff --git a/setup.py b/setup.py
index 5b66a2ea7..d6009b8d5 100644
--- a/setup.py
+++ b/setup.py
@@ -149,6 +149,9 @@ extras_require = {
     "vllm": [
         "vllm==0.7.2",
     ],
+    "llmcompressor": [
+        "llmcompressor==0.5.1",
+    ],
 }
 
 install_requires, dependency_links, extras_require_build = parse_requirements(
diff --git a/src/axolotl/integrations/llm_compressor/README.md b/src/axolotl/integrations/llm_compressor/README.md
new file mode 100644
index 000000000..16eff804d
--- /dev/null
+++ b/src/axolotl/integrations/llm_compressor/README.md
@@ -0,0 +1,108 @@
+# LLMCompressor Integration
+
+Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
+
+This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework.
+By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
+
+It uses Axolotl's plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
+
+---
+
+## Requirements
+
+- Axolotl with `llmcompressor` extras:
+
+  ```bash
+  pip install "axolotl[llmcompressor]"
+  ```
+
+- Requires `llmcompressor >= 0.5.1`
+
+This will install all necessary dependencies to fine-tune sparsified models using the integration.
+
+---
+
+## Usage
+
+To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
+
+```yaml
+plugins:
+  - axolotl.integrations.llm_compressor.LLMCompressorPlugin
+
+llmcompressor:
+  recipe:
+    finetuning_stage:
+      finetuning_modifiers:
+        ConstantPruningModifier:
+          targets: [
+            're:.*q_proj.weight',
+            're:.*k_proj.weight',
+            're:.*v_proj.weight',
+            're:.*o_proj.weight',
+            're:.*gate_proj.weight',
+            're:.*up_proj.weight',
+            're:.*down_proj.weight',
+          ]
+          start: 0
+  save_compressed: true
+# ... (other training arguments)
+```
+
+This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
+
+Pre-sparsified checkpoints can be:
+- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
+- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
+- Any custom LLM with compatible sparsity patterns that you've created yourself
+
+To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
+[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
+
+### Storage Optimization with save_compressed
+
+Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
+- Reduces disk space usage by approximately 40%
+- Maintains compatibility with vLLM for accelerated inference
+- Maintains compatibility with llmcompressor for further optimization (for example, quantization)
+
+This option is highly recommended when working with sparse models to maximize the benefits of model compression.
+
+### Example Config
+
+See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
+
+---
+
+## Inference with vLLM
+
+After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
+You can also use LLMCompressor to apply additional quantization to your fine-tuned
+sparse model before inference for even greater performance benefits:
+
+```python
+from vllm import LLM, SamplingParams
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+llm = LLM("path/to/your/sparse/model")
+outputs = llm.generate(prompts, sampling_params)
+
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+```
+
+For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
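+
+As a rough sketch of that quantization step (the integration does not run this for you), LLMCompressor's `oneshot` entrypoint can apply a data-free FP8 scheme to the fine-tuned checkpoint. The paths below are placeholders, and the top-level `oneshot` import assumes a recent llmcompressor release:
+
+```python
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# FP8 dynamic quantization needs no calibration data; skip lm_head to preserve output quality
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
+
+oneshot(
+    model="path/to/your/sparse/model",     # placeholder: your fine-tuned sparse checkpoint
+    recipe=recipe,
+    output_dir="path/to/quantized/model",  # placeholder: where to write the quantized model
+)
+```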
+
+---
+
+## Learn More
+
+For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
+
+[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)
diff --git a/src/axolotl/integrations/llm_compressor/__init__.py b/src/axolotl/integrations/llm_compressor/__init__.py
new file mode 100644
index 000000000..fe799d3c0
--- /dev/null
+++ b/src/axolotl/integrations/llm_compressor/__init__.py
@@ -0,0 +1,5 @@
+"""Integration entry point for the LLMCompressor plugin."""
+
+from .plugin import LLMCompressorPlugin
+
+__all__ = ["LLMCompressorPlugin"]
diff --git a/src/axolotl/integrations/llm_compressor/args.py b/src/axolotl/integrations/llm_compressor/args.py
new file mode 100644
index 000000000..4c0e4cac3
--- /dev/null
+++ b/src/axolotl/integrations/llm_compressor/args.py
@@ -0,0 +1,40 @@
+"""
+LLMCompressor and Sparse Finetuning config models.
+"""
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+
+class CompressionArgs(BaseModel):
+    """Sparse Finetuning config for LLMCompressor."""
+
+    # Typing for recipe is set to Any due to:
+    # https://github.com/vllm-project/llm-compressor/issues/1319
+    recipe: Annotated[
+        Any,
+        Field(
+            description="The recipe containing the compression algorithms and hyperparameters to apply."
+        ),
+    ]
+
+    save_compressed: Annotated[
+        bool,
+        Field(
+            default=False,
+            description="Whether to save the compressed model after training.",
+        ),
+    ]
+
+
+class LLMCompressorArgs(BaseModel):
+    """LLMCompressor configuration BaseModel."""
+
+    llmcompressor: Annotated[
+        CompressionArgs,
+        Field(
+            description="Arguments enabling compression pathways through the LLM Compressor plugin"
+        ),
+    ]
diff --git a/src/axolotl/integrations/llm_compressor/plugin.py b/src/axolotl/integrations/llm_compressor/plugin.py
new file mode 100644
index 000000000..d986d51f4
--- /dev/null
+++ b/src/axolotl/integrations/llm_compressor/plugin.py
@@ -0,0 +1,171 @@
+"""
+Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
+by maintaining masks for zero weights during training.
+"""
+
+import logging
+from functools import wraps
+from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
+
+from llmcompressor import active_session, create_session
+from llmcompressor.core import callbacks as session_callbacks
+from llmcompressor.recipe import Recipe
+from torch.nn import Module
+from transformers.trainer import Trainer
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+
+from axolotl.integrations.base import BasePlugin
+
+P = ParamSpec("P")  # Params for generic function signatures
+R = TypeVar("R")  # Return type for generic function signatures
+
+LOG = logging.getLogger("axolotl.integrations.llm_compressor")
+
+
+class LLMCompressorCallbackHandler(TrainerCallback):
+    """
+    Trainer callback for Sparse Finetuning.
+    Maintains sparsity patterns during training by applying masks after optimization steps,
+    ensuring zero-weight updates are canceled out.
+    """
+
+    def __init__(self, trainer: Trainer, recipe: Any):
+        """
+        Initialize the Sparse Finetuning callback handler.
+
+        Args:
+            trainer (Trainer): Hugging Face Trainer instance.
+            recipe (Recipe | dict): Sparse finetuning recipe to apply.
+ """ + super().__init__() + self.trainer = trainer + self.recipe = ( + Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe + ) + self.original_compute_loss = trainer.compute_loss + self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) + create_session() + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the beginning of training. Initializes the compression session. + + Args: + args (TrainingArguments): Training arguments. + state (TrainerState): Trainer state. + control (TrainerControl): Trainer control. + """ + super().on_train_begin(args, state, control, **kwargs) + self.trainer.accelerator.wait_for_everyone() + active_session().initialize( + model=self.trainer.model, + optimizer=self.trainer.optimizer, + start=state.epoch, + recipe=self.recipe, + ) + self.trainer.accelerator.wait_for_everyone() + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the beginning of a training step. Triggers batch_start callback. + """ + super().on_step_begin(args, state, control, **kwargs) + session_callbacks.batch_start() + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the end of a training step. Triggers optimizer and batch_end callbacks. + """ + super().on_step_end(args, state, control, **kwargs) + session_callbacks.optim_pre_step() + session_callbacks.optim_post_step() + session_callbacks.batch_end() + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the end of training. Finalizes the compression session. + """ + super().on_train_end(args, state, control, **kwargs) + active_session().finalize() + self.trainer.compute_loss_func = self.original_compute_loss + + +class LLMCompressorPlugin(BasePlugin): + """ + Sparse Finetuning plugin for Axolotl integration. + """ + + def get_input_args(self) -> str: + """ + Returns the path to the plugin's argument definition. + + Returns: + str: Dotted path to the LLMCompressorArgs class. + """ + return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs" + + def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: + """ + Adds Sparse Finetuning callback to the Trainer instance. + + Args: + cfg (Any): Configuration object containing the sparse recipe. + trainer (Trainer): Huggingface Trainer instance. + + Returns: + list: List containing the configured callback instances. + """ + LOG.info("Adding Sparse Finetuning callback to the trainer") + callback = LLMCompressorCallbackHandler( + trainer=trainer, + recipe=cfg.llmcompressor.recipe, + ) + return [callback] + + +def compute_loss_wrapper( + compute_loss_func: Callable[Concatenate[Module, P], R], +) -> Callable[Concatenate[Module, P], R]: + """ + Wraps the loss computation function to trigger the loss_calculated callback. + + Args: + compute_loss_func (Callable): Original loss computation function. + + Returns: + Callable: Wrapped function that also invokes the loss_calculated callback. 
+ """ + + @wraps(compute_loss_func) + def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R: + loss = compute_loss_func(model, *args, **kwargs) + if active_session().lifecycle.initialized_ and model.training: + session_callbacks.loss_calculated(loss=loss) + return loss + + return compute_and_notify diff --git a/src/axolotl/integrations/llm_compressor/utils.py b/src/axolotl/integrations/llm_compressor/utils.py new file mode 100644 index 000000000..f04454e5b --- /dev/null +++ b/src/axolotl/integrations/llm_compressor/utils.py @@ -0,0 +1,40 @@ +"""Utilities for llmcompressor integration with axolotl.""" + +from typing import Union + +from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + modify_save_pretrained, +) +from transformers import PreTrainedModel, Trainer + + +def save_compressed_model( + model: PreTrainedModel, + output_dir: Union[str, bytes], + trainer: Trainer, + safe_serialization: bool = False, + save_compressed: bool = False, +) -> None: + """ + Synchronize processes, apply compression hooks, and save the model. + + Args: + model (PreTrainedModel): The model to be saved. + output_dir (str or bytes): Path where the model files will be written. + trainer (Trainer): Hugging Face Trainer for process synchronization. + safe_serialization (bool): Use safe serialization if True. + save_compressed (bool): Write compressed tensors if True. + """ + trainer.accelerator.wait_for_everyone() + + # Only the main process writes the files + if not trainer.accelerator.is_main_process: + return + + modify_save_pretrained(model) + model.save_pretrained( + output_dir, + safe_serialization=safe_serialization, + save_compressed=save_compressed, + skip_sparsity_compression_stats=not save_compressed, + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 7896239de..808d3af64 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -296,8 +296,23 @@ def save_trained_model( trainer.model.save_pretrained( cfg.output_dir, safe_serialization=safe_serialization ) + model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: + # TODO: add integration support so this can be implemented completely within the plugin + from axolotl.integrations.llm_compressor.utils import ( + save_compressed_model, + ) + + save_compressed_model( + model=model, + output_dir=cfg.output_dir, + trainer=trainer, + safe_serialization=safe_serialization, + save_compressed=cfg.llmcompressor.save_compressed, + ) + def create_model_card(cfg: DictDefault, trainer: Trainer): """ diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e88de1bad..ba71ea459 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -141,6 +141,22 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): hasattr(model_config, "quantization_config") and model_config.quantization_config ) + + # Detect compressed-tensors config + is_compressed_tensors_config = ( + quant_config_exists + and model_config.quantization_config.get("quant_method") == "compressed-tensors" + ) + + if is_compressed_tensors_config: + if model_config.quantization_config.get("config_groups"): + LOG.warning( + "Found `config_groups` in a compressed-tensors config. " + "QAT integration with llmcompressor is not tested." 
+ ) + # Skip further quant checks for compressed-tensors + return + quant_config_method_is_gptq = ( quant_config_exists and "quant_method" in model_config.quantization_config diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py new file mode 100644 index 000000000..20bf821bf --- /dev/null +++ b/tests/e2e/integrations/test_llm_compressor.py @@ -0,0 +1,111 @@ +""" +E2E smoke tests for LLMCompressorPlugin integration +""" + +from pathlib import Path + +import pytest + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, prepare_plugins, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import ( + check_model_output_exists, + require_llmcompressor, + require_torch_2_4_1, +) + +MODELS = [ + "nm-testing/llama2.c-stories42M-pruned2.4-compressed", + "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed", +] + + +@pytest.mark.parametrize( + "base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"] +) +@pytest.mark.parametrize( + "save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"] +) +class TestLLMCompressorIntegration: + """ + e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin + """ + + @require_llmcompressor + @require_torch_2_4_1 + def test_llmcompressor_plugin( + self, temp_dir, base_model: str, save_compressed: bool + ): + from llmcompressor import active_session + + # core cfg + cfg = DictDefault( + { + "base_model": base_model, + "plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"], + "sequence_len": 1024, + "val_set_size": 0.05, + "special_tokens": {"pad_token": "<|endoftext|>"}, + "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "max_steps": 5, + "llmcompressor": { + "recipe": { + "finetuning_stage": { + "finetuning_modifiers": { + "ConstantPruningModifier": { + "targets": [ + "re:.*q_proj.weight", + "re:.*k_proj.weight", + "re:.*v_proj.weight", + "re:.*o_proj.weight", + "re:.*gate_proj.weight", + "re:.*up_proj.weight", + "re:.*down_proj.weight", + ], + "start": 0, + }, + }, + }, + }, + "save_compressed": save_compressed, + }, + } + ) + + prepare_plugins(cfg) + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + try: + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + _check_llmcompressor_model_outputs(temp_dir, save_compressed) + finally: + active_session().reset() + + +def _check_llmcompressor_model_outputs(temp_dir, save_compressed): + if save_compressed: + assert (Path(temp_dir) / "recipe.yaml").exists() + + from compressed_tensors import ModelCompressor + from compressed_tensors.config import Sparse24BitMaskConfig + + compressor = ModelCompressor.from_pretrained(temp_dir) + assert compressor is not None + assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 2fbf333c4..61df1d8fe 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -105,7 +105,25 @@ def require_vllm(test_case): return False return unittest.skipUnless( - 
-        is_vllm_installed(), "test requires a vllm to be installed"
+        is_vllm_installed(), "test requires vllm to be installed"
     )(test_case)
+
+
+def require_llmcompressor(test_case):
+    """
+    Decorator marking a test that requires llmcompressor to be installed
+    """
+
+    def is_llmcompressor_installed():
+        try:
+            import llmcompressor  # pylint: disable=unused-import # noqa: F401
+
+            return True
+        except ImportError:
+            return False
+
+    return unittest.skipUnless(
+        is_llmcompressor_installed(), "test requires llmcompressor to be installed"
+    )(test_case)
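
For a quick local check of this change, an invocation along these lines should work, assuming a CUDA environment and that the `axolotl` CLI is available in this version (commands mirror the extras, example config, and test path added above):

```bash
# install the integration extras, then fine-tune with the example config
pip install -e ".[llmcompressor]"
axolotl train examples/llama-3/sparse-finetuning.yaml

# or run the new e2e smoke tests
pytest tests/e2e/integrations/test_llm_compressor.py -v
```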