Update: review comments!

This commit is contained in:
Rahul Tuli
2025-03-26 21:57:00 +00:00
committed by Rahul Tuli
parent 7946f89df4
commit b76d2d1130
4 changed files with 133 additions and 95 deletions

View File

@@ -60,17 +60,19 @@ fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
recipe:
finetuning_stage: llmcompressor:
finetuning_modifiers: recipe:
ConstantPruningModifier: finetuning_stage:
targets: [ finetuning_modifiers:
're:.*q_proj.weight', ConstantPruningModifier:
're:.*k_proj.weight', targets: [
're:.*v_proj.weight', 're:.*q_proj.weight',
're:.*o_proj.weight', 're:.*k_proj.weight',
're:.*gate_proj.weight', 're:.*v_proj.weight',
're:.*up_proj.weight', 're:.*o_proj.weight',
're:.*down_proj.weight', 're:.*gate_proj.weight',
] 're:.*up_proj.weight',
start: 0 're:.*down_proj.weight',
]
start: 0

View File

@@ -1,130 +1,146 @@
""" """
Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks Sparse Finetuning plugin for Axolotl enables handling of sparse neural networks
by maintaining masks for zero weights during training. by maintaining masks for zero weights during training.
""" """
import logging import logging
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
from transformers.training_args import TrainingArguments from transformers.training_args import TrainingArguments
from ..base import BasePlugin from ..base import BasePlugin
from .args import LLMCompressorArgs # pylint: disable=unused-import. # noqa: F401 from llmcompressor import active_session
from llmcompressor import initialize
from llmcompressor.core import callbacks as session_callbacks from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe from llmcompressor.recipe import Recipe
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft") LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
class SFTCallbackHandler(TrainerCallback): class SFTCallbackHandler(TrainerCallback):
""" """
Transformer trainer callback for Sparse Finetuning. Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps. Maintains sparsity patterns during training by applying masks after optimization steps,
This ensures that optimizer updates to zero weights are canceled out. ensuring zero-weight updates are canceled out.
""" """
def __init__(self, trainer: object, recipe: object): def __init__(self, trainer: Trainer, recipe: Any):
""" """
Initialize the callback handler. Initialize the Sparse Finetuning callback handler.
Args: Args:
trainer (object): The trainer instance. trainer (Trainer): Huggingface Trainer instance.
recipe (object): The sparse finetuning recipe to be applied. recipe (Recipe | dict): Sparse finetuning recipe to apply.
""" """
super().__init__() super().__init__()
self.trainer = trainer self.trainer = trainer
self.recipe = Recipe.model_validate(recipe) self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
if hasattr(self.trainer, "compute_loss"): def on_train_begin(
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
""" """
Event triggered at the beginning of training. Called at the beginning of training. Initializes the compression session.
Updates the session reference to the model, accommodating changes due to wrappers like FSDP.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
""" """
super().on_train_begin(args, state, control, **kwargs) super().on_train_begin(args, state, control, **kwargs)
initialize( session = active_session()
session.initialize(
model=self.trainer.model, model=self.trainer.model,
optimizer=self.trainer.optimizer, optimizer=self.trainer.optimizer,
start=state.epoch, start=state.epoch,
recipe=self.recipe, recipe=self.recipe,
) )
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): def on_step_begin(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
""" """
Event triggered at the beginning of a training step. Called at the beginning of a training step. Triggers batch_start callback.
Calls batch_start in the active CompressionSession.
""" """
super().on_step_begin(args, state, control, **kwargs) super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start() session_callbacks.batch_start()
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): def on_step_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
""" """
Event triggered at the end of a training step. Called at the end of a training step. Triggers optimizer and batch_end callbacks.
Calls optimizer pre-step, post-step, and batch_end callbacks.
""" """
super().on_step_end(args, state, control, **kwargs) super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step() session_callbacks.optim_pre_step()
session_callbacks.optim_post_step() session_callbacks.optim_post_step()
session_callbacks.batch_end() session_callbacks.batch_end()
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): def on_train_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
) -> None:
""" """
Event triggered at the end of a substep during gradient accumulation. Called at the end of training. Finalizes the compression session.
Calls batch_end in the active CompressionSession.
""" """
super().on_substep_end(args, state, control, **kwargs) super().on_train_end(args, state, control, **kwargs)
session_callbacks.batch_end() session = active_session()
session.finalize()
# def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# super().on_prediction_step(args, state, control, **kwargs)
# session_callbacks.loss_calculated()
class SFTPlugin(BasePlugin): class SFTPlugin(BasePlugin):
""" """
Plugin for Sparse Finetuning integration with Axolotl. Sparse Finetuning plugin for Axolotl integration.
""" """
def get_input_args(self) -> str: def get_input_args(self) -> str:
""" """
Returns the input argument path for the plugin. Returns the path to the plugin's argument definition.
"""
return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds Sparse Finetuning callback to the trainer.
Args:
cfg (object): Configuration object containing the recipe.
trainer (object): Trainer instance to which the callback is added.
Returns: Returns:
list: A list containing the Sparse Finetuning callback. str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
""" """
LOG.info("Adding Sparse Finetuning callback to the trainer") LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = SFTCallbackHandler( callback = SFTCallbackHandler(
trainer=trainer, trainer=trainer,
recipe=cfg.recipe, recipe=cfg.llmcompressor.recipe,
) )
return [callback] return [callback]
def compute_loss_wrapper(compute_loss_func): def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
""" """
Wraps the loss computation function to integrate with the active CompressionSession. Wraps the loss computation function to trigger the loss_calculated callback.
Args: Args:
compute_loss_func (function): The original loss computation function. compute_loss_func (Callable): Original loss computation function.
Returns: Returns:
function: Wrapped function that reports the computed loss. Callable: Wrapped function that also invokes the loss_calculated callback.
""" """
def wrapper(*args, **kwargs):
@wraps(compute_loss_func)
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(*args, **kwargs) loss = compute_loss_func(*args, **kwargs)
session_callbacks.loss_calculated(loss=loss) session_callbacks.loss_calculated(loss=loss)
# take the mean across multiple GPUs
# this is done outside the compute_loss function in the parent
loss = loss.mean()
return loss return loss
return wrapper
return compute_and_notify

View File

@@ -1,13 +1,35 @@
""" """
Pydantic model for accepting `llmcompressor` specific arguments. LLMCompressor and Sparse Finetuning config models.
""" """
from typing import Optional, Any
from pydantic import BaseModel from pydantic import BaseModel, Field, ConfigDict
from typing import Any
from typing_extensions import Annotated
class SFTArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(description="Recipe config.")
]
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_assignment=True,
)
class LLMCompressorArgs(BaseModel): class LLMCompressorArgs(BaseModel):
""" """LLMCompressor configuration BaseModel."""
Input arguments for Sparse Finetuning.
""" llmcompressor: Annotated[
SFTArgs,
recipe: Optional[Any] = None Field(description="SFT llmcompressor args")
]
model_config = ConfigDict(
validate_assignment=True,
)

View File

@@ -140,22 +140,20 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
and model_config.quantization_config and model_config.quantization_config
) )
# TODO: Use a better fix to handle # Detect compressed-tensors config
# config.json produced by compressed-tensors is_compressed_tensors_config = (
# sparse-only model -> will also have a quantization_config quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
is_sparse_only_quant_config = bool(
not quant_config_exists
or (
quant_config_exists
and model_config.quantization_config["quant_method"] == "compressed-tensors"
and not model_config.quantization_config.get("config_groups", False)
and model_config.quantization_config.get("sparsity_config", False)
)
) )
if is_sparse_only_quant_config: if is_compressed_tensors_config:
quant_config_exists = False if model_config.quantization_config.get("config_groups"):
LOG.warn(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = ( quant_config_method_is_gptq = (
quant_config_exists quant_config_exists