Compare commits

..

17 Commits

Author SHA1 Message Date
Rahul Tuli
3a8b637598 Tests, Style, Updates 2025-04-30 17:21:52 -04:00
Rahul Tuli
12cd09e6f5 Rebase and updates! 2025-04-30 17:21:52 -04:00
Rahul Tuli
fe82f62248 Add: llm_compressor integration documentation 2025-04-30 17:21:52 -04:00
Rahul Tuli
db31d7ad22 Move: LLMCompressorPlugin into it's own submodule 2025-04-30 17:21:52 -04:00
Rahul Tuli
eb7f2aa4b9 Update model config 2025-04-30 17:21:51 -04:00
Rahul Tuli
f80e36ddd2 Use: absolute import 2025-04-30 17:21:51 -04:00
Rahul Tuli
412d2ec6d0 Rename: sft.yaml to sparse-finetuning.yaml 2025-04-30 17:21:51 -04:00
Rahul Tuli
50fc5e6984 Add: llcompressor installable 2025-04-30 17:21:51 -04:00
Rahul Tuli
83a88b745f Address review comments from @markurtz 2025-04-30 17:21:51 -04:00
Rahul Tuli
8855bb115f Apply suggestions from @markurtz
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
2025-04-30 17:21:51 -04:00
Rahul Tuli
ef9543b371 Update llmcompressor version to latest 2025-04-30 17:21:51 -04:00
Rahul Tuli
25e701e885 Revert: TODO's 2025-04-30 17:21:50 -04:00
Rahul Tuli
891a21e599 Use: warning over warn 2025-04-30 17:21:50 -04:00
Rahul Tuli
8beb2f27ad pre commit hooks 2025-04-30 17:21:50 -04:00
Rahul Tuli
56ba66b60f Add:llmcompressor instalable 2025-04-30 17:21:50 -04:00
Rahul Tuli
13d4b865d6 Update: review comments! 2025-04-30 17:21:50 -04:00
Rahul Tuli
3da866b2b9 Add: SFTPlugin with llmcompressor 2025-04-30 17:21:50 -04:00
23 changed files with 55 additions and 241 deletions

View File

@@ -30,7 +30,7 @@ jobs:
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.0 pytorch: 2.7.0
axolotl_extras: axolotl_extras: vllm
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -261,18 +261,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"

View File

@@ -1,90 +0,0 @@
{
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
}
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -49,8 +49,7 @@ sections = [
("Knowledge Distillation (KD)", "kd"), ("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"), ("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"), ("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum"), ("Spectrum", "spectrum")
("LLMCompressor", "llm_compressor")
] ]
for section_name, folder_name in sections: for section_name, folder_name in sections:

View File

@@ -18,7 +18,7 @@ accelerate==1.6.0
datasets==3.5.0 datasets==3.5.0
deepspeed>=0.15.4 deepspeed>=0.15.4
trl==0.17.0 trl==0.17.0
hf_xet==1.1.0 hf_xet==1.0.0
hqq==0.2.5 hqq==0.2.5
optimum==1.16.2 optimum==1.16.2

View File

@@ -2,7 +2,4 @@
import os import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -8,6 +8,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@@ -5,7 +5,6 @@ import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union from typing import Union
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -159,9 +158,7 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg plugin_manager.cfg = cfg
def load_cfg( def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
""" """
Loads the `axolotl` configuration stored at `config`, validates it, and performs Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup. various setup.
@@ -173,24 +170,13 @@ def load_cfg(
Returns: Returns:
`DictDefault` mapping configuration keys to values. `DictDefault` mapping configuration keys to values.
""" """
if isinstance(config, (str, Path)): config = check_remote_config(config)
config = check_remote_config(config) if Path(config).is_dir():
if Path(config).is_dir(): config = choose_config(Path(config))
config = choose_config(Path(config))
# Load the config from the yaml file # Load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# If there are any options passed in the cli, if it is something that seems valid # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value # from the yaml, then overwrite the value
@@ -204,6 +190,8 @@ def load_cfg(
else: else:
cfg[k] = kwargs[k] cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try: try:
device_props = torch.cuda.get_device_properties("cuda") device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)

View File

@@ -20,9 +20,11 @@ from transformers import (
ProcessorMixin, ProcessorMixin,
) )
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets( def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None, cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets, calling Loads one or more training or evaluation datasets, calling
@@ -64,8 +64,7 @@ def load_datasets(
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = ( preprocess_iterable = (
cli_args hasattr(cli_args, "iterable")
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None and cli_args.iterable is not None
and cli_args.iterable and cli_args.iterable
) )
@@ -77,7 +76,7 @@ def load_datasets(
preprocess_iterable=preprocess_iterable, preprocess_iterable=preprocess_iterable,
) )
if cli_args and ( if (
cli_args.debug cli_args.debug
or cfg.debug or cfg.debug
or cli_args.debug_text_only or cli_args.debug_text_only

View File

@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used # these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = ( training_arguments_kwargs["max_steps"] = (
self.cfg.max_steps if self.cfg.max_steps else -1 total_num_steps if self.cfg.max_steps else -1
) )
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = ( training_arguments_kwargs["per_device_train_batch_size"] = (

View File

@@ -63,7 +63,6 @@ class GRPOStrategy:
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.reward_weights: if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import ( from axolotl.train import (
TrainDatasetMeta, TrainDatasetMeta,
setup_model_and_tokenizer, setup_model_and_tokenizer,
@@ -23,6 +24,7 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -45,7 +45,6 @@ llmcompressor:
're:.*down_proj.weight', 're:.*down_proj.weight',
] ]
start: 0 start: 0
save_compressed: true
# ... (other training arguments) # ... (other training arguments)
``` ```
@@ -53,56 +52,19 @@ This plugin **does not apply pruning or sparsification itself** — it is intend
Pre-sparsified checkpoints can be: Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor) - Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic) - Or downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation: To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md) [https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config ### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example. See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
--- ---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More ## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository: For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) 👉 [https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -55,16 +55,13 @@ def dequantize(
target_device = W.device target_device = W.device
# Extract quantization state # Extract quantization state
nested = False
if not isinstance(quant_state, list): if not isinstance(quant_state, list):
# New style quant_state class # New style quant_state class
absmax = quant_state.absmax.to(target_device) absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape shape = quant_state.shape
dtype = quant_state.dtype dtype = quant_state.dtype
blocksize = quant_state.blocksize blocksize = quant_state.blocksize
if quant_state.nested: offset = quant_state.offset.to(target_device)
nested = True
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2 state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device) absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device) code2 = state2.code.to(target_device)
@@ -118,8 +115,7 @@ def dequantize(
ctypes.c_int(n_elements_absmax), ctypes.c_int(n_elements_absmax),
) )
if nested: out_absmax += offset
out_absmax += offset
# Choose appropriate dequantization function # Choose appropriate dequantization function
fx = ( fx = (

View File

@@ -12,8 +12,10 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -30,6 +30,7 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager, SequenceParallelContextManager,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -41,6 +42,7 @@ try:
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -286,19 +288,7 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError: except FileNotFoundError:
pass pass
elif cfg.local_rank == 0: elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import ( from axolotl.integrations.llm_compressor.utils import (
save_compressed_model, save_compressed_model,
) )
@@ -311,6 +301,17 @@ def save_trained_model(
save_compressed=cfg.llmcompressor.save_compressed, save_compressed=cfg.llmcompressor.save_compressed,
) )
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Trainer): def create_model_card(cfg: DictDefault, trainer: Trainer):
""" """

View File

@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
else: else:
LOG.debug("bf16 support not detected, disabling for this configuration.") LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False cfg.bf16 = False
if cfg.fp16 is None and not cfg.float16: if cfg.fp16 is None:
cfg.fp16 = True cfg.fp16 = True
if cfg.device == "mps": if cfg.device == "mps":

View File

@@ -67,12 +67,6 @@ class TRLConfig(BaseModel):
default=False, default=False,
json_schema_extra={"description": "Whether to log completions"}, json_schema_extra={"description": "Whether to log completions"},
) )
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
},
)
sync_ref_model: bool | None = Field( sync_ref_model: bool | None = Field(
default=False, default=False,
json_schema_extra={ json_schema_extra={

View File

@@ -597,8 +597,6 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16: elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg): def prepare_opinionated_env(cfg):

View File

@@ -9,14 +9,10 @@ import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import ( from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
check_model_output_exists,
require_llmcompressor,
require_torch_2_4_1,
)
MODELS = [ MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed", "nm-testing/llama2.c-stories42M-pruned2.4-compressed",
@@ -35,13 +31,10 @@ class TestLLMCompressorIntegration:
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
""" """
@require_llmcompressor
@require_torch_2_4_1 @require_torch_2_4_1
def test_llmcompressor_plugin( def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool self, temp_dir, base_model: str, save_compressed: bool
): ):
from llmcompressor import active_session
# core cfg # core cfg
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -86,23 +79,22 @@ class TestLLMCompressorIntegration:
) )
prepare_plugins(cfg) prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
try: train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg)
check_model_output_exists(temp_dir, cfg) _check_llmcompressor_model_outputs(temp_dir, save_compressed)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
finally:
active_session().reset()
def _check_llmcompressor_model_outputs(temp_dir, save_compressed): def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
if save_compressed:
assert (Path(temp_dir) / "recipe.yaml").exists()
# recipe.yaml should exist
assert (Path(temp_dir) / "recipe.yaml").exists()
# sparsity config exists if save_compressed
if save_compressed:
from compressed_tensors import ModelCompressor from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig from compressed_tensors.config import Sparse24BitMaskConfig

View File

@@ -105,25 +105,7 @@ def require_vllm(test_case):
return False return False
return unittest.skipUnless( return unittest.skipUnless(
is_vllm_installed(), "test requires vllm to be installed" is_vllm_installed(), "test requires a vllm to be installed"
)(test_case)
def require_llmcompressor(test_case):
"""
Decorator marking a test that requires a llmcompressor to be installed
"""
def is_llmcompressor_installed():
try:
import llmcompressor # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
)(test_case) )(test_case)