Update: review comments!

Rahul Tuli
2025-03-26 21:57:00 +00:00
committed by Rahul Tuli
parent f9d6776c28
commit 47a333ce49
4 changed files with 133 additions and 95 deletions

View File

@@ -60,17 +60,19 @@ fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
recipe:
  finetuning_stage:
    finetuning_modifiers:
      ConstantPruningModifier:
        targets: [
          're:.*q_proj.weight',
          're:.*k_proj.weight',
          're:.*v_proj.weight',
          're:.*o_proj.weight',
          're:.*gate_proj.weight',
          're:.*up_proj.weight',
          're:.*down_proj.weight',
        ]
        start: 0
llmcompressor:
  recipe:
    finetuning_stage:
      finetuning_modifiers:
        ConstantPruningModifier:
          targets: [
            're:.*q_proj.weight',
            're:.*k_proj.weight',
            're:.*v_proj.weight',
            're:.*o_proj.weight',
            're:.*gate_proj.weight',
            're:.*up_proj.weight',
            're:.*down_proj.weight',
          ]
          start: 0
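
For context, a minimal sketch of what the move to a nested `llmcompressor.recipe` key looks like from Python, mirroring how the plugin below reads `cfg.llmcompressor.recipe` and validates it with `Recipe.model_validate`. The config path and the use of PyYAML are assumptions for illustration only.

import yaml  # assumed available; used only to show where the recipe now lives
from llmcompressor.recipe import Recipe

# Hypothetical axolotl-style config file using the nested layout shown above.
with open("sft_sparse_config.yaml") as f:
    cfg = yaml.safe_load(f)

recipe_dict = cfg["llmcompressor"]["recipe"]  # previously cfg["recipe"]
recipe = Recipe.model_validate(recipe_dict)   # same validation the callback handler performs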

View File

@@ -1,130 +1,146 @@
"""
Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
Sparse Finetuning plugin for Axolotl enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
from transformers.training_args import TrainingArguments
from ..base import BasePlugin
from .args import LLMCompressorArgs  # pylint: disable=unused-import  # noqa: F401
from llmcompressor import initialize
from llmcompressor import active_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
class SFTCallbackHandler(TrainerCallback):
"""
Transformer trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps.
This ensures that optimizer updates to zero weights are canceled out.
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: object, recipe: object):
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the callback handler.
Initialize the Sparse Finetuning callback handler.
Args:
trainer (object): The trainer instance.
recipe (object): The sparse finetuning recipe to be applied.
trainer (Trainer): Hugging Face Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = Recipe.model_validate(recipe)
self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
if hasattr(self.trainer, "compute_loss"):
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_begin(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
"""
Event triggered at the beginning of training.
Updates the session reference to the model, accommodating changes due to wrappers like FSDP.
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
initialize(
session = active_session()
session.initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_step_begin(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
"""
Event triggered at the beginning of a training step.
Calls batch_start in the active CompressionSession.
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_step_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
"""
Event triggered at the end of a training step.
Calls optimizer pre-step, post-step, and batch_end callbacks.
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
"""
Event triggered at the end of a substep during gradient accumulation.
Calls batch_end in the active CompressionSession.
Called at the end of training. Finalizes the compression session.
"""
super().on_substep_end(args, state, control, **kwargs)
session_callbacks.batch_end()
# def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# super().on_prediction_step(args, state, control, **kwargs)
# session_callbacks.loss_calculated()
super().on_train_end(args, state, control, **kwargs)
session = active_session()
session.finalize()
class SFTPlugin(BasePlugin):
"""
Plugin for Sparse Finetuning integration with Axolotl.
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the input argument path for the plugin.
"""
return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"
Returns the path to the plugin's argument definition.
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds Sparse Finetuning callback to the trainer.
Args:
cfg (object): Configuration object containing the recipe.
trainer (object): Trainer instance to which the callback is added.
Returns:
list: A list containing the Sparse Finetuning callback.
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Hugging Face Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = SFTCallbackHandler(
trainer=trainer,
recipe=cfg.recipe,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(compute_loss_func):
def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
"""
Wraps the loss computation function to integrate with the active CompressionSession.
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (function): The original loss computation function.
compute_loss_func (Callable): Original loss computation function.
Returns:
function: Wrapped function that reports the computed loss.
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
def wrapper(*args, **kwargs):
@wraps(compute_loss_func)
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(*args, **kwargs)
session_callbacks.loss_calculated(loss=loss)
# take the mean across multiple GPUs
# this is done outside the compute_loss function in the parent
loss = loss.mean()
return loss
return wrapper
return compute_and_notify
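
As a self-contained sketch of the wrapping pattern used by `compute_loss_wrapper` above: a decorator that forwards the result to a notification hook before returning it. Here `notify` and the toy `compute_loss` are stand-ins with no transformers or llmcompressor dependency; in the plugin the hook is `session_callbacks.loss_calculated`.

from functools import wraps
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def notify(value: object) -> None:
    # Stand-in for session_callbacks.loss_calculated(loss=value).
    print(f"loss reported: {value}")

def with_notification(func: Callable[P, R]) -> Callable[P, R]:
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        result = func(*args, **kwargs)
        notify(result)  # report the value before handing it back to the caller
        return result
    return wrapper

@with_notification
def compute_loss(prediction: float, target: float) -> float:
    return (prediction - target) ** 2

print(compute_loss(1.0, 3.0))  # prints the report, then 4.0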

View File

@@ -1,13 +1,35 @@
"""
Pydantic model for accepting `llmcompressor` specific arguments.
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Optional, Any
from pydantic import BaseModel
from pydantic import BaseModel, Field, ConfigDict
from typing import Any
from typing_extensions import Annotated
class SFTArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(description="Recipe config.")
]
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_assignment=True,
)
class LLMCompressorArgs(BaseModel):
"""
Input arguments for Sparse Finetuning.
"""
recipe: Optional[Any] = None
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
SFTArgs,
Field(description="SFT llmcompressor args")
]
model_config = ConfigDict(
validate_assignment=True,
)
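
A short usage sketch for the nested models above, assuming pydantic v2 (the diff already uses `model_validate` and `ConfigDict`). The import path follows the dotted path returned by `get_input_args`, and the recipe dict is illustrative only.

from axolotl.integrations.llmcompressor_sft.args import LLMCompressorArgs

# The nested dict mirrors the YAML layout shown in the example config above.
args = LLMCompressorArgs.model_validate({
    "llmcompressor": {
        "recipe": {
            "finetuning_stage": {
                "finetuning_modifiers": {
                    "ConstantPruningModifier": {
                        "targets": ["re:.*q_proj.weight"],
                        "start": 0,
                    }
                }
            }
        }
    }
})
print(args.llmcompressor.recipe)  # recipe is typed as Any, so the dict passes through unchanged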

View File

@@ -140,22 +140,20 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
and model_config.quantization_config
)
# TODO: Use a better fix to handle
# config.json produced by compressed-tensors
# sparse-only model -> will also have a quantization_config
is_sparse_only_quant_config = bool(
not quant_config_exists
or (
quant_config_exists
and model_config.quantization_config["quant_method"] == "compressed-tensors"
and not model_config.quantization_config.get("config_groups", False)
and model_config.quantization_config.get("sparsity_config", False)
)
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_sparse_only_quant_config:
quant_config_exists = False
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = (
quant_config_exists