From db31d7ad227a064cd96c7c05fed45f7d8681dc16 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 9 Apr 2025 00:27:48 +0000
Subject: [PATCH] Move: LLMCompressorPlugin into its own submodule

---
 .../integrations/llm_compressor/__init__.py | 165 +-----------------
 .../integrations/llm_compressor/plugin.py   | 164 +++++++++++++++++
 2 files changed, 167 insertions(+), 162 deletions(-)
 create mode 100644 src/axolotl/integrations/llm_compressor/plugin.py

diff --git a/src/axolotl/integrations/llm_compressor/__init__.py b/src/axolotl/integrations/llm_compressor/__init__.py
index d4797b7c2..fe799d3c0 100644
--- a/src/axolotl/integrations/llm_compressor/__init__.py
+++ b/src/axolotl/integrations/llm_compressor/__init__.py
@@ -1,164 +1,5 @@
-"""
-Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
-by maintaining masks for zero weights during training.
-"""
+"""Integration entry point for the LLMCompressor plugin."""
 
-import logging
-from functools import wraps
-from typing import Any, Callable, ParamSpec, TypeVar
+from .plugin import LLMCompressorPlugin
 
-from llmcompressor import active_session
-from llmcompressor.core import callbacks as session_callbacks
-from llmcompressor.recipe import Recipe
-from transformers.trainer import Trainer
-from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
-from transformers.training_args import TrainingArguments
-
-from axolotl.integrations.base import BasePlugin
-
-P = ParamSpec("P")  # Params for generic function signatures
-R = TypeVar("R")  # Return type for generic function signatures
-
-LOG = logging.getLogger("axolotl.integrations.llm_compressor")
-
-
-class LLMCompressorCallbackHandler(TrainerCallback):
-    """
-    Trainer callback for Sparse Finetuning.
-    Maintains sparsity patterns during training by applying masks after optimization steps,
-    ensuring zero-weight updates are canceled out.
-    """
-
-    def __init__(self, trainer: Trainer, recipe: Any):
-        """
-        Initialize the Sparse Finetuning callback handler.
-
-        Args:
-            trainer (Trainer): Huggingface Trainer instance.
-            recipe (Recipe | dict): Sparse finetuning recipe to apply.
-        """
-        super().__init__()
-        self.trainer = trainer
-        self.recipe = (
-            Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
-        )
-        self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
-
-    def on_train_begin(
-        self,
-        args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ) -> None:
-        """
-        Called at the beginning of training. Initializes the compression session.
-
-        Args:
-            args (TrainingArguments): Training arguments.
-            state (TrainerState): Trainer state.
-            control (TrainerControl): Trainer control.
-        """
-        super().on_train_begin(args, state, control, **kwargs)
-        session = active_session()
-        session.initialize(
-            model=self.trainer.model,
-            optimizer=self.trainer.optimizer,
-            start=state.epoch,
-            recipe=self.recipe,
-        )
-
-    def on_step_begin(
-        self,
-        args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ) -> None:
-        """
-        Called at the beginning of a training step. Triggers batch_start callback.
-        """
-        super().on_step_begin(args, state, control, **kwargs)
-        session_callbacks.batch_start()
-
-    def on_step_end(
-        self,
-        args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ) -> None:
-        """
-        Called at the end of a training step. Triggers optimizer and batch_end callbacks.
- """ - super().on_step_end(args, state, control, **kwargs) - session_callbacks.optim_pre_step() - session_callbacks.optim_post_step() - session_callbacks.batch_end() - - def on_train_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ) -> None: - """ - Called at the end of training. Finalizes the compression session. - """ - super().on_train_end(args, state, control, **kwargs) - session = active_session() - session.finalize() - - -class LLMCompressorPlugin(BasePlugin): - """ - Sparse Finetuning plugin for Axolotl integration. - """ - - def get_input_args(self) -> str: - """ - Returns the path to the plugin's argument definition. - - Returns: - str: Dotted path to the LLMCompressorArgs class. - """ - return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs" - - def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: - """ - Adds Sparse Finetuning callback to the Trainer instance. - - Args: - cfg (Any): Configuration object containing the sparse recipe. - trainer (Trainer): Huggingface Trainer instance. - - Returns: - list: List containing the configured callback instances. - """ - LOG.info("Adding Sparse Finetuning callback to the trainer") - callback = LLMCompressorCallbackHandler( - trainer=trainer, - recipe=cfg.llmcompressor.recipe, - ) - return [callback] - - -def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]: - """ - Wraps the loss computation function to trigger the loss_calculated callback. - - Args: - compute_loss_func (Callable): Original loss computation function. - - Returns: - Callable: Wrapped function that also invokes the loss_calculated callback. - """ - - @wraps(compute_loss_func) - def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R: - loss = compute_loss_func(*args, **kwargs) - session_callbacks.loss_calculated(loss=loss) - return loss - - return compute_and_notify +__all__ = ["LLMCompressorPlugin"] diff --git a/src/axolotl/integrations/llm_compressor/plugin.py b/src/axolotl/integrations/llm_compressor/plugin.py new file mode 100644 index 000000000..d4797b7c2 --- /dev/null +++ b/src/axolotl/integrations/llm_compressor/plugin.py @@ -0,0 +1,164 @@ +""" +Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks +by maintaining masks for zero weights during training. +""" + +import logging +from functools import wraps +from typing import Any, Callable, ParamSpec, TypeVar + +from llmcompressor import active_session +from llmcompressor.core import callbacks as session_callbacks +from llmcompressor.recipe import Recipe +from transformers.trainer import Trainer +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from axolotl.integrations.base import BasePlugin + +P = ParamSpec("P") # Params for generic function signatures +R = TypeVar("R") # Return type for generic function signatures + +LOG = logging.getLogger("axolotl.integrations.llm_compressor") + + +class LLMCompressorCallbackHandler(TrainerCallback): + """ + Trainer callback for Sparse Finetuning. + Maintains sparsity patterns during training by applying masks after optimization steps, + ensuring zero-weight updates are canceled out. + """ + + def __init__(self, trainer: Trainer, recipe: Any): + """ + Initialize the Sparse Finetuning callback handler. + + Args: + trainer (Trainer): Huggingface Trainer instance. + recipe (Recipe | dict): Sparse finetuning recipe to apply. 
+ """ + super().__init__() + self.trainer = trainer + self.recipe = ( + Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe + ) + self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the beginning of training. Initializes the compression session. + + Args: + args (TrainingArguments): Training arguments. + state (TrainerState): Trainer state. + control (TrainerControl): Trainer control. + """ + super().on_train_begin(args, state, control, **kwargs) + session = active_session() + session.initialize( + model=self.trainer.model, + optimizer=self.trainer.optimizer, + start=state.epoch, + recipe=self.recipe, + ) + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the beginning of a training step. Triggers batch_start callback. + """ + super().on_step_begin(args, state, control, **kwargs) + session_callbacks.batch_start() + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the end of a training step. Triggers optimizer and batch_end callbacks. + """ + super().on_step_end(args, state, control, **kwargs) + session_callbacks.optim_pre_step() + session_callbacks.optim_post_step() + session_callbacks.batch_end() + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ) -> None: + """ + Called at the end of training. Finalizes the compression session. + """ + super().on_train_end(args, state, control, **kwargs) + session = active_session() + session.finalize() + + +class LLMCompressorPlugin(BasePlugin): + """ + Sparse Finetuning plugin for Axolotl integration. + """ + + def get_input_args(self) -> str: + """ + Returns the path to the plugin's argument definition. + + Returns: + str: Dotted path to the LLMCompressorArgs class. + """ + return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs" + + def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: + """ + Adds Sparse Finetuning callback to the Trainer instance. + + Args: + cfg (Any): Configuration object containing the sparse recipe. + trainer (Trainer): Huggingface Trainer instance. + + Returns: + list: List containing the configured callback instances. + """ + LOG.info("Adding Sparse Finetuning callback to the trainer") + callback = LLMCompressorCallbackHandler( + trainer=trainer, + recipe=cfg.llmcompressor.recipe, + ) + return [callback] + + +def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]: + """ + Wraps the loss computation function to trigger the loss_calculated callback. + + Args: + compute_loss_func (Callable): Original loss computation function. + + Returns: + Callable: Wrapped function that also invokes the loss_calculated callback. + """ + + @wraps(compute_loss_func) + def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R: + loss = compute_loss_func(*args, **kwargs) + session_callbacks.loss_calculated(loss=loss) + return loss + + return compute_and_notify