Compare commits

..

6 Commits

Author SHA1 Message Date
Wing Lian
1a22d16842 handle empty offset for quant state 2025-05-01 13:01:00 -04:00
Wing Lian
fee3c13bb5 Logging config for colab (#2611)
* only configure logging on cli to play nicely with colab

* allow reloading the config on the fly from a dict

* make sure to use dict for yaml

* reuse existing function for load

* make cli args optional

* mps fix and respect max_steps
2025-05-01 12:58:00 -04:00
Rahul Tuli
996fc124e5 Add: Sparse Finetuning Integration with llmcompressor (#2479)
* Add: SFTPlugin with llmcompressor

* Update: review comments!

* Add:llmcompressor instalable

* pre commit hooks

* Use: warning over warn

* Revert: TODO's

* Update llmcompressor version to latest

* Apply suggestions from @markurtz

Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>

* Address review comments from @markurtz

* Add: llcompressor installable

* Rename: sft.yaml to sparse-finetuning.yaml

* Use: absolute import

* Update model config

* Move: LLMCompressorPlugin into it's own submodule

* Add: `llm_compressor` integration documentation

* Rebase and updates!

* Tests, Style, Updates

* Add: .qmd file

* Address Review Comments:
* deleted redundant docs/llm_compressor.qmd
* incorporated feedback in integration README.md
* added llmcompressor integration to docs/custom_integrations.qmd

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* Add: line about further optimizations using llmcompressor

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* Apply patch from @winglian

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* Fix: Test

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* additional fixes for docker and saving compressed

* split llmcompressor from vllm checks

* Reset session between tests

Signed-off-by: Rahul Tuli <rtuli@redhat.com>

* move decorator to test method instead of class

* make sure to reset the session after each test

* move import of llmcompressor to reset session inside test

---------

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-01 12:25:16 -04:00
Wing Lian
e963990ad7 add missing __init__ for lr monkeypatch fix (#2609) 2025-05-01 09:41:32 -04:00
Dhruv Mullick
c3f2b1c5c2 Add num_completions_to_print for trl and grpo (#2604) 2025-04-30 21:00:30 -04:00
Wing Lian
6ba5c0ed2c use latest hf-xet and don't install vllm for torch 2.7.0 (#2603)
* use latest hf-xet and don't install vllm for torch 2.7.0

* fix runpod hub tests
2025-04-30 18:27:39 -04:00
23 changed files with 241 additions and 55 deletions

View File

@@ -30,7 +30,7 @@ jobs:
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.0 pytorch: 2.7.0
axolotl_extras: vllm axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -261,6 +261,18 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"

90
.runpod/tests.json Normal file
View File

@@ -0,0 +1,90 @@
{
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
}
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -49,7 +49,8 @@ sections = [
("Knowledge Distillation (KD)", "kd"), ("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"), ("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"), ("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum") ("Spectrum", "spectrum"),
("LLMCompressor", "llm_compressor")
] ]
for section_name, folder_name in sections: for section_name, folder_name in sections:

View File

@@ -18,7 +18,7 @@ accelerate==1.6.0
datasets==3.5.0 datasets==3.5.0
deepspeed>=0.15.4 deepspeed>=0.15.4
trl==0.17.0 trl==0.17.0
hf_xet==1.0.0 hf_xet==1.1.0
hqq==0.2.5 hqq==0.2.5
optimum==1.16.2 optimum==1.16.2

View File

@@ -2,4 +2,7 @@
import os import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union from typing import Union
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg plugin_manager.cfg = cfg
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault: def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
""" """
Loads the `axolotl` configuration stored at `config`, validates it, and performs Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup. various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
Returns: Returns:
`DictDefault` mapping configuration keys to values. `DictDefault` mapping configuration keys to values.
""" """
config = check_remote_config(config) if isinstance(config, (str, Path)):
if Path(config).is_dir(): config = check_remote_config(config)
config = choose_config(Path(config)) if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file # Load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# If there are any options passed in the cli, if it is something that seems valid # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value # from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
else: else:
cfg[k] = kwargs[k] cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try: try:
device_props = torch.cuda.get_device_properties("cuda") device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)

View File

@@ -20,11 +20,9 @@ from transformers import (
ProcessorMixin, ProcessorMixin,
) )
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets( def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs], cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets, calling Loads one or more training or evaluation datasets, calling
@@ -64,7 +64,8 @@ def load_datasets(
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = ( preprocess_iterable = (
hasattr(cli_args, "iterable") cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None and cli_args.iterable is not None
and cli_args.iterable and cli_args.iterable
) )
@@ -76,7 +77,7 @@ def load_datasets(
preprocess_iterable=preprocess_iterable, preprocess_iterable=preprocess_iterable,
) )
if ( if cli_args and (
cli_args.debug cli_args.debug
or cfg.debug or cfg.debug
or cli_args.debug_text_only or cli_args.debug_text_only

View File

@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used # these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = ( training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1 self.cfg.max_steps if self.cfg.max_steps else -1
) )
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = ( training_arguments_kwargs["per_device_train_batch_size"] = (

View File

@@ -63,6 +63,7 @@ class GRPOStrategy:
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.reward_weights: if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import ( from axolotl.train import (
TrainDatasetMeta, TrainDatasetMeta,
setup_model_and_tokenizer, setup_model_and_tokenizer,
@@ -24,7 +23,6 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -45,6 +45,7 @@ llmcompressor:
're:.*down_proj.weight', 're:.*down_proj.weight',
] ]
start: 0 start: 0
save_compressed: true
# ... (other training arguments) # ... (other training arguments)
``` ```
@@ -52,19 +53,56 @@ This plugin **does not apply pruning or sparsification itself** — it is intend
Pre-sparsified checkpoints can be: Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor) - Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Or downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic) - Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation: To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md) [https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config ### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example. See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
--- ---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More ## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository: For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
👉 [https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) [https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -55,13 +55,16 @@ def dequantize(
target_device = W.device target_device = W.device
# Extract quantization state # Extract quantization state
nested = False
if not isinstance(quant_state, list): if not isinstance(quant_state, list):
# New style quant_state class # New style quant_state class
absmax = quant_state.absmax.to(target_device) absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape shape = quant_state.shape
dtype = quant_state.dtype dtype = quant_state.dtype
blocksize = quant_state.blocksize blocksize = quant_state.blocksize
offset = quant_state.offset.to(target_device) if quant_state.nested:
nested = True
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2 state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device) absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device) code2 = state2.code.to(target_device)
@@ -115,7 +118,8 @@ def dequantize(
ctypes.c_int(n_elements_absmax), ctypes.c_int(n_elements_absmax),
) )
out_absmax += offset if nested:
out_absmax += offset
# Choose appropriate dequantization function # Choose appropriate dequantization function
fx = ( fx = (

View File

@@ -12,10 +12,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager, SequenceParallelContextManager,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +41,6 @@ try:
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -288,7 +286,19 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError: except FileNotFoundError:
pass pass
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import ( from axolotl.integrations.llm_compressor.utils import (
save_compressed_model, save_compressed_model,
) )
@@ -301,17 +311,6 @@ def save_trained_model(
save_compressed=cfg.llmcompressor.save_compressed, save_compressed=cfg.llmcompressor.save_compressed,
) )
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Trainer): def create_model_card(cfg: DictDefault, trainer: Trainer):
""" """

View File

@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
else: else:
LOG.debug("bf16 support not detected, disabling for this configuration.") LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False cfg.bf16 = False
if cfg.fp16 is None: if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True cfg.fp16 = True
if cfg.device == "mps": if cfg.device == "mps":

View File

@@ -67,6 +67,12 @@ class TRLConfig(BaseModel):
default=False, default=False,
json_schema_extra={"description": "Whether to log completions"}, json_schema_extra={"description": "Whether to log completions"},
) )
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
},
)
sync_ref_model: bool | None = Field( sync_ref_model: bool | None = Field(
default=False, default=False,
json_schema_extra={ json_schema_extra={

View File

@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16: elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg): def prepare_opinionated_env(cfg):

View File

@@ -9,10 +9,14 @@ import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 from tests.e2e.utils import (
check_model_output_exists,
require_llmcompressor,
require_torch_2_4_1,
)
MODELS = [ MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed", "nm-testing/llama2.c-stories42M-pruned2.4-compressed",
@@ -31,10 +35,13 @@ class TestLLMCompressorIntegration:
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
""" """
@require_llmcompressor
@require_torch_2_4_1 @require_torch_2_4_1
def test_llmcompressor_plugin( def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool self, temp_dir, base_model: str, save_compressed: bool
): ):
from llmcompressor import active_session
# core cfg # core cfg
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -79,22 +86,23 @@ class TestLLMCompressorIntegration:
) )
prepare_plugins(cfg) prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) try:
check_model_output_exists(temp_dir, cfg) train(cfg=cfg, dataset_meta=dataset_meta)
_check_llmcompressor_model_outputs(temp_dir, save_compressed) check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
finally:
active_session().reset()
def _check_llmcompressor_model_outputs(temp_dir, save_compressed): def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
# recipe.yaml should exist
assert (Path(temp_dir) / "recipe.yaml").exists()
# sparsity config exists if save_compressed
if save_compressed: if save_compressed:
assert (Path(temp_dir) / "recipe.yaml").exists()
from compressed_tensors import ModelCompressor from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig from compressed_tensors.config import Sparse24BitMaskConfig

View File

@@ -105,7 +105,25 @@ def require_vllm(test_case):
return False return False
return unittest.skipUnless( return unittest.skipUnless(
is_vllm_installed(), "test requires a vllm to be installed" is_vllm_installed(), "test requires vllm to be installed"
)(test_case)
def require_llmcompressor(test_case):
"""
Decorator marking a test that requires llmcompressor to be installed
"""
def is_llmcompressor_installed():
try:
import llmcompressor # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
)(test_case) )(test_case)