Tests, Style, Updates
This commit is contained in:
1
setup.py
1
setup.py
@@ -149,7 +149,6 @@ extras_require = {
|
|||||||
"vllm": [
|
"vllm": [
|
||||||
"vllm==0.7.2",
|
"vllm==0.7.2",
|
||||||
],
|
],
|
||||||
# PENDING: https://github.com/vllm-project/llm-compressor/pull/1352
|
|
||||||
"llmcompressor": [
|
"llmcompressor": [
|
||||||
"llmcompressor==0.5.1",
|
"llmcompressor==0.5.1",
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ LLMCompressor and Sparse Finetuning config models.
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
@@ -38,7 +38,3 @@ class LLMCompressorArgs(BaseModel):
|
|||||||
description="Arguments enabling compression pathways through the LLM Compressor plugins"
|
description="Arguments enabling compression pathways through the LLM Compressor plugins"
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
|
||||||
validate_assignment=True,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ by maintaining masks for zero weights during training.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, ParamSpec, TypeVar
|
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
from llmcompressor import active_session, create_session
|
from llmcompressor import active_session, create_session
|
||||||
from llmcompressor.core import callbacks as session_callbacks
|
from llmcompressor.core import callbacks as session_callbacks
|
||||||
from llmcompressor.recipe import Recipe
|
from llmcompressor.recipe import Recipe
|
||||||
|
from torch.nn import Module
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||||
from transformers.training_args import TrainingArguments
|
from transformers.training_args import TrainingArguments
|
||||||
@@ -42,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
|||||||
self.recipe = (
|
self.recipe = (
|
||||||
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
||||||
)
|
)
|
||||||
|
self.original_compute_loss = trainer.compute_loss
|
||||||
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
||||||
create_session()
|
create_session()
|
||||||
|
|
||||||
@@ -110,6 +112,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
|||||||
"""
|
"""
|
||||||
super().on_train_end(args, state, control, **kwargs)
|
super().on_train_end(args, state, control, **kwargs)
|
||||||
active_session().finalize()
|
active_session().finalize()
|
||||||
|
# Undo the patch made in __init__: the original was read from
# `trainer.compute_loss` and the wrapper was installed on that same attribute,
# so it has to be restored there (not on `compute_loss_func`).
self.trainer.compute_loss = self.original_compute_loss
|
||||||
|
|
||||||
|
|
||||||
class LLMCompressorPlugin(BasePlugin):
|
class LLMCompressorPlugin(BasePlugin):
|
||||||
@@ -145,7 +148,9 @@ class LLMCompressorPlugin(BasePlugin):
|
|||||||
return [callback]
|
return [callback]
|
||||||
|
|
||||||
|
|
||||||
def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
|
def compute_loss_wrapper(
|
||||||
|
compute_loss_func: Callable[Concatenate[Module, P], R],
|
||||||
|
) -> Callable[Concatenate[Module, P], R]:
|
||||||
"""
|
"""
|
||||||
Wraps the loss computation function to trigger the loss_calculated callback.
|
Wraps the loss computation function to trigger the loss_calculated callback.
|
||||||
|
|
||||||
@@ -157,9 +162,9 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(compute_loss_func)
|
@wraps(compute_loss_func)
|
||||||
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
|
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
loss = compute_loss_func(*args, **kwargs)
|
loss = compute_loss_func(model, *args, **kwargs)
|
||||||
if active_session().lifecycle.initialized_:
|
if active_session().lifecycle.initialized_ and model.training:
|
||||||
session_callbacks.loss_calculated(loss=loss)
|
session_callbacks.loss_calculated(loss=loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,40 @@
|
|||||||
from transformers import Trainer
|
"""Utilities for llmcompressor integration with axolotl."""
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
|
||||||
|
modify_save_pretrained,
|
||||||
|
)
|
||||||
|
from transformers import PreTrainedModel, Trainer
|
||||||
|
|
||||||
|
|
||||||
def save_compressed_model(
    model: PreTrainedModel,
    output_dir: Union[str, bytes],
    trainer: Trainer,
    safe_serialization: bool = False,
    save_compressed: bool = False,
) -> None:
    """
    Save the model through llmcompressor's patched ``save_pretrained``.

    All processes first wait on the accelerator barrier. After that, only the
    main process patches the model and writes it out.

    Args:
        model (PreTrainedModel): Model to save.
        output_dir (str or bytes): Destination handed to ``save_pretrained``.
        trainer (Trainer): Trainer whose accelerator supplies the barrier and
            the main-process check.
        safe_serialization (bool): Forwarded to ``save_pretrained``.
        save_compressed (bool): Forwarded to ``save_pretrained``; its negation
            is passed as ``skip_sparsity_compression_stats``.
    """
    trainer.accelerator.wait_for_everyone()

    if trainer.accelerator.is_main_process:
        modify_save_pretrained(model)
        save_options = {
            "safe_serialization": safe_serialization,
            "save_compressed": save_compressed,
            "skip_sparsity_compression_stats": not save_compressed,
        }
        model.save_pretrained(output_dir, **save_options)
|
||||||
|
|||||||
103
tests/e2e/integrations/test_llm_compressor.py
Normal file
103
tests/e2e/integrations/test_llm_compressor.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke tests for LLMCompressorPlugin integration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
|
||||||
|
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
    "save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
class TestLLMCompressorIntegration:
    """
    End-to-end smoke coverage for
    axolotl.integrations.llm_compressor.LLMCompressorPlugin
    """

    @require_torch_2_4_1
    def test_llmcompressor_plugin(
        self, temp_dir, base_model: str, save_compressed: bool
    ):
        """Train for a few steps with the plugin enabled, then check the outputs."""
        # weights targeted by the ConstantPruningModifier entry of the recipe
        projection_names = (
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        )
        pruning_modifier = {
            "targets": [f"re:.*{name}.weight" for name in projection_names],
            "start": 0,
        }
        recipe = {
            "finetuning_stage": {
                "finetuning_modifiers": {
                    "ConstantPruningModifier": pruning_modifier,
                },
            },
        }

        cfg = DictDefault(
            {
                "base_model": base_model,
                "plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
                "sequence_len": 1024,
                "val_set_size": 0.05,
                "special_tokens": {"pad_token": "<|endoftext|>"},
                "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "save_safetensors": True,
                "bf16": "auto",
                "max_steps": 5,
                "llmcompressor": {
                    "recipe": recipe,
                    "save_compressed": save_compressed,
                },
            }
        )

        prepare_plugins(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())

        train(cfg=cfg, dataset_meta=dataset_meta)

        # generic output check first, then the llmcompressor-specific artifacts
        check_model_output_exists(temp_dir, cfg)
        _check_llmcompressor_model_outputs(temp_dir, save_compressed)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
    """
    Assert on the llmcompressor artifacts written to ``temp_dir``.

    ``recipe.yaml`` must always be present. When ``save_compressed`` is truthy,
    a ``ModelCompressor`` must also be loadable from the directory and carry a
    ``Sparse24BitMaskConfig`` sparsity config.
    """
    recipe_file = Path(temp_dir) / "recipe.yaml"
    assert recipe_file.exists()

    # nothing further to verify for uncompressed saves
    if not save_compressed:
        return

    # imported only on the compressed path
    from compressed_tensors import ModelCompressor
    from compressed_tensors.config import Sparse24BitMaskConfig

    compressor = ModelCompressor.from_pretrained(temp_dir)
    assert compressor is not None
    assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)
|
||||||
Reference in New Issue
Block a user