Update: review comments!
@@ -60,17 +60,19 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: <|end_of_text|>
-recipe:
-  finetuning_stage:
-    finetuning_modifiers:
-      ConstantPruningModifier:
-        targets: [
-          're:.*q_proj.weight',
-          're:.*k_proj.weight',
-          're:.*v_proj.weight',
-          're:.*o_proj.weight',
-          're:.*gate_proj.weight',
-          're:.*up_proj.weight',
-          're:.*down_proj.weight',
-        ]
-        start: 0
+llmcompressor:
+  recipe:
+    finetuning_stage:
+      finetuning_modifiers:
+        ConstantPruningModifier:
+          targets: [
+            're:.*q_proj.weight',
+            're:.*k_proj.weight',
+            're:.*v_proj.weight',
+            're:.*o_proj.weight',
+            're:.*gate_proj.weight',
+            're:.*up_proj.weight',
+            're:.*down_proj.weight',
+          ]
+          start: 0
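The review change nests the pruning recipe under a dedicated `llmcompressor` key instead of the top-level `recipe` key. A minimal Python sketch of what that relocation means when reading such a config (the file path is hypothetical; PyYAML assumed):

import yaml

with open("config.yaml") as f:  # hypothetical path; layout as in the hunk above
    cfg = yaml.safe_load(f)

recipe = cfg["llmcompressor"]["recipe"]  # previously read from cfg["recipe"]
modifier = recipe["finetuning_stage"]["finetuning_modifiers"]["ConstantPruningModifier"]
print(modifier["targets"], modifier["start"])
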
@@ -1,130 +1,146 @@
 """
-Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
+Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
 by maintaining masks for zero weights during training.
 """

 import logging
+from functools import wraps
+from typing import Callable, TypeVar, ParamSpec, Any

+from transformers.trainer import Trainer
 from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
 from transformers.training_args import TrainingArguments

 from ..base import BasePlugin
 from .args import LLMCompressorArgs  # pylint: disable=unused-import. # noqa: F401
-from llmcompressor import initialize
+from llmcompressor import active_session
 from llmcompressor.core import callbacks as session_callbacks
 from llmcompressor.recipe import Recipe

+P = ParamSpec("P")  # Params for generic function signatures
+R = TypeVar("R")  # Return type for generic function signatures
+
 LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")


 class SFTCallbackHandler(TrainerCallback):
     """
-    Transformer trainer callback for Sparse Finetuning.
-    Maintains sparsity patterns during training by applying masks after optimization steps.
-    This ensures that optimizer updates to zero weights are canceled out.
+    Trainer callback for Sparse Finetuning.
+    Maintains sparsity patterns during training by applying masks after optimization steps,
+    ensuring zero-weight updates are canceled out.
     """

-    def __init__(self, trainer: object, recipe: object):
+    def __init__(self, trainer: Trainer, recipe: Any):
         """
-        Initialize the callback handler.
+        Initialize the Sparse Finetuning callback handler.

         Args:
-            trainer (object): The trainer instance.
-            recipe (object): The sparse finetuning recipe to be applied.
+            trainer (Trainer): Huggingface Trainer instance.
+            recipe (Recipe | dict): Sparse finetuning recipe to apply.
         """
         super().__init__()
         self.trainer = trainer
-        self.recipe = Recipe.model_validate(recipe)
-        self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
+        self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+
+        if hasattr(self.trainer, "compute_loss"):
+            self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)

-    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_train_begin(
+        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+    ) -> None:
         """
-        Event triggered at the beginning of training.
-        Updates the session reference to the model, accommodating changes due to wrappers like FSDP.
+        Called at the beginning of training. Initializes the compression session.

         Args:
             args (TrainingArguments): Training arguments.
             state (TrainerState): Trainer state.
             control (TrainerControl): Trainer control.
         """
         super().on_train_begin(args, state, control, **kwargs)
-        initialize(
+        session = active_session()
+        session.initialize(
             model=self.trainer.model,
             optimizer=self.trainer.optimizer,
             start=state.epoch,
             recipe=self.recipe,
         )

-    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_step_begin(
+        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+    ) -> None:
         """
-        Event triggered at the beginning of a training step.
-        Calls batch_start in the active CompressionSession.
+        Called at the beginning of a training step. Triggers batch_start callback.
         """
         super().on_step_begin(args, state, control, **kwargs)
         session_callbacks.batch_start()

-    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_step_end(
+        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+    ) -> None:
         """
-        Event triggered at the end of a training step.
-        Calls optimizer pre-step, post-step, and batch_end callbacks.
+        Called at the end of a training step. Triggers optimizer and batch_end callbacks.
         """
         super().on_step_end(args, state, control, **kwargs)
         session_callbacks.optim_pre_step()
         session_callbacks.optim_post_step()
         session_callbacks.batch_end()

-    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_train_end(
+        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+    ) -> None:
         """
-        Event triggered at the end of a substep during gradient accumulation.
-        Calls batch_end in the active CompressionSession.
+        Called at the end of training. Finalizes the compression session.
         """
-        super().on_substep_end(args, state, control, **kwargs)
-        session_callbacks.batch_end()
-
-    # def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-    #     super().on_prediction_step(args, state, control, **kwargs)
-    #     session_callbacks.loss_calculated()
+        super().on_train_end(args, state, control, **kwargs)
+        session = active_session()
+        session.finalize()


 class SFTPlugin(BasePlugin):
     """
-    Plugin for Sparse Finetuning integration with Axolotl.
+    Sparse Finetuning plugin for Axolotl integration.
     """

     def get_input_args(self) -> str:
         """
-        Returns the input argument path for the plugin.
-        """
-        return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"
-
-    def add_callbacks_post_trainer(self, cfg, trainer):
-        """
-        Adds Sparse Finetuning callback to the trainer.
-
-        Args:
-            cfg (object): Configuration object containing the recipe.
-            trainer (object): Trainer instance to which the callback is added.
-
-        Returns:
-            list: A list containing the Sparse Finetuning callback.
-        """
+        Returns the path to the plugin's argument definition.
+
+        Returns:
+            str: Dotted path to the LLMCompressorArgs class.
+        """
+        return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
+
+    def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
+        """
+        Adds Sparse Finetuning callback to the Trainer instance.
+
+        Args:
+            cfg (Any): Configuration object containing the sparse recipe.
+            trainer (Trainer): Huggingface Trainer instance.
+
+        Returns:
+            list: List containing the configured callback instances.
+        """
         LOG.info("Adding Sparse Finetuning callback to the trainer")
         callback = SFTCallbackHandler(
             trainer=trainer,
-            recipe=cfg.recipe,
+            recipe=cfg.llmcompressor.recipe,
         )
         return [callback]


-def compute_loss_wrapper(compute_loss_func):
+def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
     """
-    Wraps the loss computation function to integrate with the active CompressionSession.
+    Wraps the loss computation function to trigger the loss_calculated callback.

     Args:
-        compute_loss_func (function): The original loss computation function.
+        compute_loss_func (Callable): Original loss computation function.

     Returns:
-        function: Wrapped function that reports the computed loss.
+        Callable: Wrapped function that also invokes the loss_calculated callback.
     """
-    def wrapper(*args, **kwargs):
+
+    @wraps(compute_loss_func)
+    def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
         loss = compute_loss_func(*args, **kwargs)
         session_callbacks.loss_calculated(loss=loss)
         # take the mean across multiple GPUs
         # this is done outside the compute_loss function in the parent
         loss = loss.mean()
         return loss
-    return wrapper
+
+    return compute_and_notify
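The wrapper leaves the Trainer's own loss computation untouched and only reports the raw loss to the compression session before reducing it. A self-contained sketch of the same decorator pattern, with `notify` standing in for `session_callbacks.loss_calculated` and a toy loss function (torch assumed available):

from functools import wraps

import torch


def notify(loss):
    # stand-in for session_callbacks.loss_calculated
    print(f"loss reported: {loss.tolist()}")


def wrap(compute_loss):
    @wraps(compute_loss)
    def inner(*args, **kwargs):
        loss = compute_loss(*args, **kwargs)  # per-element loss, as the Trainer would return
        notify(loss)                          # report before reduction, as the plugin does
        return loss.mean()                    # mean across elements/devices
    return inner


@wrap
def toy_loss(pred, target):
    return (pred - target) ** 2


print(toy_loss(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0])))  # tensor(2.5000)
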
@@ -1,13 +1,35 @@
 """
-Pydantic model for accepting `llmcompressor` specific arguments.
+LLMCompressor and Sparse Finetuning config models.
 """
-from typing import Optional, Any
-
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, ConfigDict
+from typing import Any
+from typing_extensions import Annotated
+
+
+class SFTArgs(BaseModel):
+    """Sparse Finetuning config for LLMCompressor."""
+
+    # Typing for recipe is set to Any due to:
+    # https://github.com/vllm-project/llm-compressor/issues/1319
+    recipe: Annotated[
+        Any,
+        Field(description="Recipe config.")
+    ]
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        validate_assignment=True,
+    )


 class LLMCompressorArgs(BaseModel):
     """
     Input arguments for Sparse Finetuning.
     """

-    recipe: Optional[Any] = None
-    """LLMCompressor configuration BaseModel."""
+    llmcompressor: Annotated[
+        SFTArgs,
+        Field(description="SFT llmcompressor args")
+    ]
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
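With these models, the plugin's input is a nested `llmcompressor.recipe` section rather than a bare top-level `recipe`. A usage sketch, assuming axolotl with this plugin installed; the recipe stays a plain dict because the field is typed `Any`:

from axolotl.integrations.llmcompressor_sft.args import LLMCompressorArgs

args = LLMCompressorArgs.model_validate(
    {
        "llmcompressor": {
            "recipe": {
                "finetuning_stage": {
                    "finetuning_modifiers": {
                        "ConstantPruningModifier": {
                            "targets": ["re:.*q_proj.weight"],
                            "start": 0,
                        }
                    }
                }
            }
        }
    }
)
print(type(args.llmcompressor.recipe))  # plain dict, since `recipe` is typed Any
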
@@ -140,22 +140,20 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         and model_config.quantization_config
     )

-    # TODO: Use a better fix to handle
-    # config.json produced by compressed-tensors
-    # sparse-only model -> will also have a quantization_config
-    is_sparse_only_quant_config = bool(
-        not quant_config_exists
-        or (
-            quant_config_exists
-            and model_config.quantization_config["quant_method"] == "compressed-tensors"
-            and not model_config.quantization_config.get("config_groups", False)
-            and model_config.quantization_config.get("sparsity_config", False)
-        )
-    )
-
-    if is_sparse_only_quant_config:
-        quant_config_exists = False
+    # Detect compressed-tensors config
+    is_compressed_tensors_config = (
+        quant_config_exists
+        and model_config.quantization_config.get("quant_method") == "compressed-tensors"
+    )
+
+    if is_compressed_tensors_config:
+        if model_config.quantization_config.get("config_groups"):
+            LOG.warn(
+                "Found `config_groups` in a compressed-tensors config. "
+                "QAT integration with llmcompressor is not tested."
+            )
+        # Skip further quant checks for compressed-tensors
+        return

     quant_config_method_is_gptq = (
         quant_config_exists
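The new check treats any compressed-tensors `quantization_config` as handled by llmcompressor and only warns when `config_groups` is present. A toy illustration of that detection with a hand-written dict (keys taken from the diff, values illustrative):

quantization_config = {
    "quant_method": "compressed-tensors",
    "sparsity_config": {},  # contents omitted; only the keys above are read by the check
}

is_compressed_tensors_config = (
    quantization_config.get("quant_method") == "compressed-tensors"
)
has_config_groups = bool(quantization_config.get("config_groups"))
print(is_compressed_tensors_config, has_config_groups)  # True False -> skip quant checks, no warning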