From 47a333ce49e5af9ad2f0808621e16270ca022161 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 26 Mar 2025 21:57:00 +0000
Subject: [PATCH] Update: review comments!

---
 examples/llama-3/sft.yaml                  |  30 ++--
 .../llmcompressor_sft/__init__.py          | 134 ++++++++++--------
 .../integrations/llmcompressor_sft/args.py |  38 +++--
 src/axolotl/utils/models.py                |  26 ++--
 4 files changed, 133 insertions(+), 95 deletions(-)

diff --git a/examples/llama-3/sft.yaml b/examples/llama-3/sft.yaml
index 5077be697..0ec32626f 100644
--- a/examples/llama-3/sft.yaml
+++ b/examples/llama-3/sft.yaml
@@ -60,17 +60,19 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: <|end_of_text|>
-recipe:
-  finetuning_stage:
-    finetuning_modifiers:
-      ConstantPruningModifier:
-        targets: [
-          're:.*q_proj.weight',
-          're:.*k_proj.weight',
-          're:.*v_proj.weight',
-          're:.*o_proj.weight',
-          're:.*gate_proj.weight',
-          're:.*up_proj.weight',
-          're:.*down_proj.weight',
-        ]
-        start: 0
\ No newline at end of file
+
+llmcompressor:
+  recipe:
+    finetuning_stage:
+      finetuning_modifiers:
+        ConstantPruningModifier:
+          targets: [
+            're:.*q_proj.weight',
+            're:.*k_proj.weight',
+            're:.*v_proj.weight',
+            're:.*o_proj.weight',
+            're:.*gate_proj.weight',
+            're:.*up_proj.weight',
+            're:.*down_proj.weight',
+          ]
+          start: 0
\ No newline at end of file
diff --git a/src/axolotl/integrations/llmcompressor_sft/__init__.py b/src/axolotl/integrations/llmcompressor_sft/__init__.py
index c888f6797..b99daef47 100644
--- a/src/axolotl/integrations/llmcompressor_sft/__init__.py
+++ b/src/axolotl/integrations/llmcompressor_sft/__init__.py
@@ -1,130 +1,146 @@
 """
-Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
+Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
 by maintaining masks for zero weights during training.
 """
 import logging
+from functools import wraps
+from typing import Callable, TypeVar, ParamSpec, Any
+
+from transformers.trainer import Trainer
 from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
 from transformers.training_args import TrainingArguments

 from ..base import BasePlugin
-from .args import LLMCompressorArgs  # pylint: disable=unused-import. # noqa: F401
-from llmcompressor import initialize
+from llmcompressor import active_session
 from llmcompressor.core import callbacks as session_callbacks
 from llmcompressor.recipe import Recipe

+P = ParamSpec("P")  # Params for generic function signatures
+R = TypeVar("R")  # Return type for generic function signatures
+
 LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")

+
 class SFTCallbackHandler(TrainerCallback):
     """
-    Transformer trainer callback for Sparse Finetuning.
-    Maintains sparsity patterns during training by applying masks after optimization steps.
-    This ensures that optimizer updates to zero weights are canceled out.
+    Trainer callback for Sparse Finetuning.
+    Maintains sparsity patterns during training by applying masks after optimization steps,
+    ensuring zero-weight updates are canceled out.
     """

-    def __init__(self, trainer: object, recipe: object):
+    def __init__(self, trainer: Trainer, recipe: Any):
         """
-        Initialize the callback handler.
-
+        Initialize the Sparse Finetuning callback handler.
+
         Args:
-            trainer (object): The trainer instance.
-            recipe (object): The sparse finetuning recipe to be applied.
+            trainer (Trainer): Huggingface Trainer instance.
+            recipe (Recipe | dict): Sparse finetuning recipe to apply.
""" super().__init__() self.trainer = trainer - self.recipe = Recipe.model_validate(recipe) + self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe + self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) - if hasattr(self.trainer, "compute_loss"): - self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) - - def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_train_begin( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs + ) -> None: """ - Event triggered at the beginning of training. - Updates the session reference to the model, accommodating changes due to wrappers like FSDP. + Called at the beginning of training. Initializes the compression session. + + Args: + args (TrainingArguments): Training arguments. + state (TrainerState): Trainer state. + control (TrainerControl): Trainer control. """ super().on_train_begin(args, state, control, **kwargs) - initialize( + session = active_session() + session.initialize( model=self.trainer.model, optimizer=self.trainer.optimizer, start=state.epoch, recipe=self.recipe, ) - def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_step_begin( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs + ) -> None: """ - Event triggered at the beginning of a training step. - Calls batch_start in the active CompressionSession. + Called at the beginning of a training step. Triggers batch_start callback. """ super().on_step_begin(args, state, control, **kwargs) session_callbacks.batch_start() - def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_step_end( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs + ) -> None: """ - Event triggered at the end of a training step. - Calls optimizer pre-step, post-step, and batch_end callbacks. + Called at the end of a training step. Triggers optimizer and batch_end callbacks. """ super().on_step_end(args, state, control, **kwargs) session_callbacks.optim_pre_step() session_callbacks.optim_post_step() session_callbacks.batch_end() - def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + def on_train_end( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs + ) -> None: """ - Event triggered at the end of a substep during gradient accumulation. - Calls batch_end in the active CompressionSession. + Called at the end of training. Finalizes the compression session. """ - super().on_substep_end(args, state, control, **kwargs) - session_callbacks.batch_end() - - # def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - # super().on_prediction_step(args, state, control, **kwargs) - # session_callbacks.loss_calculated() + super().on_train_end(args, state, control, **kwargs) + session = active_session() + session.finalize() + class SFTPlugin(BasePlugin): """ - Plugin for Sparse Finetuning integration with Axolotl. + Sparse Finetuning plugin for Axolotl integration. """ def get_input_args(self) -> str: """ - Returns the input argument path for the plugin. - """ - return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs" + Returns the path to the plugin's argument definition. 
-    def add_callbacks_post_trainer(self, cfg, trainer):
-        """
-        Adds Sparse Finetuning callback to the trainer.
-
-        Args:
-            cfg (object): Configuration object containing the recipe.
-            trainer (object): Trainer instance to which the callback is added.

         Returns:
-            list: A list containing the Sparse Finetuning callback.
+            str: Dotted path to the LLMCompressorArgs class.
+        """
+        return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
+
+    def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
+        """
+        Adds Sparse Finetuning callback to the Trainer instance.
+
+        Args:
+            cfg (Any): Configuration object containing the sparse recipe.
+            trainer (Trainer): Huggingface Trainer instance.
+
+        Returns:
+            list: List containing the configured callback instances.
         """
         LOG.info("Adding Sparse Finetuning callback to the trainer")
         callback = SFTCallbackHandler(
             trainer=trainer,
-            recipe=cfg.recipe,
+            recipe=cfg.llmcompressor.recipe,
         )
         return [callback]


-def compute_loss_wrapper(compute_loss_func):
+def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
     """
-    Wraps the loss computation function to integrate with the active CompressionSession.
-
+    Wraps the loss computation function to trigger the loss_calculated callback.
+
     Args:
-        compute_loss_func (function): The original loss computation function.
-
+        compute_loss_func (Callable): Original loss computation function.
+
     Returns:
-        function: Wrapped function that reports the computed loss.
+        Callable: Wrapped function that also invokes the loss_calculated callback.
     """
-    def wrapper(*args, **kwargs):
+
+    @wraps(compute_loss_func)
+    def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
         loss = compute_loss_func(*args, **kwargs)
         session_callbacks.loss_calculated(loss=loss)
-        # take the mean across multiple GPUs
-        # this is done outside the compute_loss function in the parent
-        loss = loss.mean()
         return loss
-    return wrapper
\ No newline at end of file
+
+    return compute_and_notify
diff --git a/src/axolotl/integrations/llmcompressor_sft/args.py b/src/axolotl/integrations/llmcompressor_sft/args.py
index a117b862c..24106bc32 100644
--- a/src/axolotl/integrations/llmcompressor_sft/args.py
+++ b/src/axolotl/integrations/llmcompressor_sft/args.py
@@ -1,13 +1,35 @@
 """
-Pydantic model for accepting `llmcompressor` specific arguments.
+LLMCompressor and Sparse Finetuning config models.
 """
-from typing import Optional, Any
-from pydantic import BaseModel
+
+from pydantic import BaseModel, Field, ConfigDict
+from typing import Any
+from typing_extensions import Annotated
+
+
+class SFTArgs(BaseModel):
+    """Sparse Finetuning config for LLMCompressor."""
+
+    # Typing for recipe is set to Any due to:
+    # https://github.com/vllm-project/llm-compressor/issues/1319
+    recipe: Annotated[
+        Any,
+        Field(description="Recipe config.")
+    ]
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        validate_assignment=True,
+    )


 class LLMCompressorArgs(BaseModel):
-    """
-    Input arguments for Sparse Finetuning.
-    """
-
-    recipe: Optional[Any] = None
\ No newline at end of file
+    """LLMCompressor configuration BaseModel."""
+
+    llmcompressor: Annotated[
+        SFTArgs,
+        Field(description="SFT llmcompressor args")
+    ]
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 61e6eea3b..cd8c8a9d9 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -140,22 +140,20 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         and model_config.quantization_config
     )

-    # TODO: Use a better fix to handle
-    # config.json produced by compressed-tensors
-    # sparse-only model -> will also have a quantization_config
-
-    is_sparse_only_quant_config = bool(
-        not quant_config_exists
-        or (
-            quant_config_exists
-            and model_config.quantization_config["quant_method"] == "compressed-tensors"
-            and not model_config.quantization_config.get("config_groups", False)
-            and model_config.quantization_config.get("sparsity_config", False)
-        )
+    # Detect compressed-tensors config
+    is_compressed_tensors_config = (
+        quant_config_exists
+        and model_config.quantization_config.get("quant_method") == "compressed-tensors"
     )

-    if is_sparse_only_quant_config:
-        quant_config_exists = False
+    if is_compressed_tensors_config:
+        if model_config.quantization_config.get("config_groups"):
+            LOG.warning(
+                "Found `config_groups` in a compressed-tensors config. "
+                "QAT integration with llmcompressor is not tested."
+            )
+        # Skip further quant checks for compressed-tensors
+        return

     quant_config_method_is_gptq = (
         quant_config_exists
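
Usage note: a minimal sketch of an axolotl config excerpt wired up for this plugin, assuming axolotl's usual `plugins:` registration mechanism; the plugin path and the abbreviated target list are illustrative only, while the nested `llmcompressor.recipe` layout mirrors the sft.yaml hunk above.

    # hypothetical config excerpt; the plugins entry is an assumption, not part of this patch
    plugins:
      - axolotl.integrations.llmcompressor_sft.SFTPlugin

    llmcompressor:
      recipe:
        finetuning_stage:
          finetuning_modifiers:
            ConstantPruningModifier:
              targets: ['re:.*q_proj.weight', 're:.*v_proj.weight']  # abbreviated target list
              start: 0

With this layout, SFTPlugin reads cfg.llmcompressor.recipe (validated through LLMCompressorArgs and SFTArgs) and hands it to Recipe.model_validate inside SFTCallbackHandler, so the recipe no longer sits at the top level of the config.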