diff --git a/setup.py b/setup.py index bee138da2..c84178597 100644 --- a/setup.py +++ b/setup.py @@ -149,7 +149,6 @@ extras_require = { "vllm": [ "vllm==0.7.2", ], - # PENDING: https://github.com/vllm-project/llm-compressor/pull/1352 "llmcompressor": [ "llmcompressor==0.5.1", ], diff --git a/src/axolotl/integrations/llm_compressor/args.py b/src/axolotl/integrations/llm_compressor/args.py index 5ab62325f..4c0e4cac3 100644 --- a/src/axolotl/integrations/llm_compressor/args.py +++ b/src/axolotl/integrations/llm_compressor/args.py @@ -4,7 +4,7 @@ LLMCompressor and Sparse Finetuning config models. from typing import Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -38,7 +38,3 @@ class LLMCompressorArgs(BaseModel): description="Arguments enabling compression pathways through the LLM Compressor plugins" ), ] - - model_config = ConfigDict( - validate_assignment=True, - ) diff --git a/src/axolotl/integrations/llm_compressor/plugin.py b/src/axolotl/integrations/llm_compressor/plugin.py index 45895e42d..d986d51f4 100644 --- a/src/axolotl/integrations/llm_compressor/plugin.py +++ b/src/axolotl/integrations/llm_compressor/plugin.py @@ -5,11 +5,12 @@ by maintaining masks for zero weights during training. import logging from functools import wraps -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, Callable, Concatenate, ParamSpec, TypeVar from llmcompressor import active_session, create_session from llmcompressor.core import callbacks as session_callbacks from llmcompressor.recipe import Recipe +from torch.nn import Module from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState from transformers.training_args import TrainingArguments @@ -42,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback): self.recipe = ( Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe ) + self.original_compute_loss = trainer.compute_loss self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) create_session() @@ -110,6 +112,7 @@ class LLMCompressorCallbackHandler(TrainerCallback): """ super().on_train_end(args, state, control, **kwargs) active_session().finalize() + self.trainer.compute_loss_func = self.original_compute_loss class LLMCompressorPlugin(BasePlugin): @@ -145,7 +148,9 @@ class LLMCompressorPlugin(BasePlugin): return [callback] -def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]: +def compute_loss_wrapper( + compute_loss_func: Callable[Concatenate[Module, P], R], +) -> Callable[Concatenate[Module, P], R]: """ Wraps the loss computation function to trigger the loss_calculated callback. @@ -157,9 +162,9 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]: """ @wraps(compute_loss_func) - def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R: - loss = compute_loss_func(*args, **kwargs) - if active_session().lifecycle.initialized_: + def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R: + loss = compute_loss_func(model, *args, **kwargs) + if active_session().lifecycle.initialized_ and model.training: session_callbacks.loss_calculated(loss=loss) return loss diff --git a/src/axolotl/integrations/llm_compressor/utils.py b/src/axolotl/integrations/llm_compressor/utils.py index 945c0f3ac..f04454e5b 100644 --- a/src/axolotl/integrations/llm_compressor/utils.py +++ b/src/axolotl/integrations/llm_compressor/utils.py @@ -1,15 +1,40 @@ -from transformers import Trainer +"""Utilities for llmcompressor integration with axolotl.""" + +from typing import Union + +from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + modify_save_pretrained, +) +from transformers import PreTrainedModel, Trainer + def save_compressed_model( - model, output_dir, trainer: Trainer, safe_serialization: bool, save_compressed:bool -): - from llmcompressor.transformers.sparsification.compressed_tensors_utils import modify_save_pretrained + model: PreTrainedModel, + output_dir: Union[str, bytes], + trainer: Trainer, + safe_serialization: bool = False, + save_compressed: bool = False, +) -> None: + """ + Synchronize processes, apply compression hooks, and save the model. + + Args: + model (PreTrainedModel): The model to be saved. + output_dir (str or bytes): Path where the model files will be written. + trainer (Trainer): Hugging Face Trainer for process synchronization. + safe_serialization (bool): Use safe serialization if True. + save_compressed (bool): Write compressed tensors if True. + """ trainer.accelerator.wait_for_everyone() - if trainer.accelerator.is_main_process: - modify_save_pretrained(model) - model.save_pretrained( - output_dir, - safe_serialization=safe_serialization, - save_compressed=save_compressed, - skip_sparsity_compression_stats=not save_compressed, - ) \ No newline at end of file + + # Only the main process writes the files + if not trainer.accelerator.is_main_process: + return + + modify_save_pretrained(model) + model.save_pretrained( + output_dir, + safe_serialization=safe_serialization, + save_compressed=save_compressed, + skip_sparsity_compression_stats=not save_compressed, + ) diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py new file mode 100644 index 000000000..707beacbf --- /dev/null +++ b/tests/e2e/integrations/test_llm_compressor.py @@ -0,0 +1,103 @@ +""" +E2E smoke tests for LLMCompressorPlugin integration +""" + +from pathlib import Path + +import pytest + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, prepare_plugins +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 + +MODELS = [ + "nm-testing/llama2.c-stories42M-pruned2.4-compressed", + "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed", +] + + +@pytest.mark.parametrize( + "base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"] +) +@pytest.mark.parametrize( + "save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"] +) +class TestLLMCompressorIntegration: + """ + e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin + """ + + @require_torch_2_4_1 + def test_llmcompressor_plugin( + self, temp_dir, base_model: str, save_compressed: bool + ): + # core cfg + cfg = DictDefault( + { + "base_model": base_model, + "plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"], + "sequence_len": 1024, + "val_set_size": 0.05, + "special_tokens": {"pad_token": "<|endoftext|>"}, + "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "max_steps": 5, + "llmcompressor": { + "recipe": { + "finetuning_stage": { + "finetuning_modifiers": { + "ConstantPruningModifier": { + "targets": [ + "re:.*q_proj.weight", + "re:.*k_proj.weight", + "re:.*v_proj.weight", + "re:.*o_proj.weight", + "re:.*gate_proj.weight", + "re:.*up_proj.weight", + "re:.*down_proj.weight", + ], + "start": 0, + }, + }, + }, + }, + "save_compressed": save_compressed, + }, + } + ) + + prepare_plugins(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + _check_llmcompressor_model_outputs(temp_dir, save_compressed) + + +def _check_llmcompressor_model_outputs(temp_dir, save_compressed): + + # recipe.yaml should exist + assert (Path(temp_dir) / "recipe.yaml").exists() + + # sparsity config exists if save_compressed + if save_compressed: + from compressed_tensors import ModelCompressor + from compressed_tensors.config import Sparse24BitMaskConfig + + compressor = ModelCompressor.from_pretrained(temp_dir) + assert compressor is not None + assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)