pre commit hooks

Author: Rahul Tuli
Date: 2025-04-02 22:35:36 +00:00
Committed by: Wing Lian
Parent: 56ba66b60f
Commit: 8beb2f27ad
4 changed files with 40 additions and 26 deletions

File 1 of 4

@@ -1,4 +1,4 @@
-base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
+base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
 # TODO: change to
 # base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
@@ -68,11 +68,11 @@ llmcompressor:
     ConstantPruningModifier:
       targets: [
        're:.*q_proj.weight',
-       're:.*k_proj.weight',
+       're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
-     start: 0
+     start: 0
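For context: the plugin in the next file validates whatever recipe the config carries via Recipe.model_validate, so a dict form of the recipe above can be checked ahead of training. A minimal sketch; the dict nesting is an assumption mirroring the YAML, not something this commit defines:

    from llmcompressor.recipe import Recipe

    # Hypothetical dict form of the YAML recipe above; the exact nesting
    # Recipe expects is an assumption (see llm-compressor issue #1319 on
    # recipe typing, referenced in the args file below).
    raw_recipe = {
        "ConstantPruningModifier": {
            "targets": [
                "re:.*q_proj.weight",
                "re:.*k_proj.weight",
                "re:.*v_proj.weight",
            ],
            "start": 0,
        }
    }

    recipe = Recipe.model_validate(raw_recipe)  # same call the handler uses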

File 2 of 4

@@ -5,19 +5,19 @@ by maintaining masks for zero weights during training.
 import logging
 from functools import wraps
-from typing import Callable, TypeVar, ParamSpec, Any
+from typing import Any, Callable, ParamSpec, TypeVar
-from transformers.trainer import Trainer
-from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
-from transformers.training_args import TrainingArguments
-from ..base import BasePlugin
 from llmcompressor import active_session
 from llmcompressor.core import callbacks as session_callbacks
 from llmcompressor.recipe import Recipe
+from transformers.trainer import Trainer
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+from ..base import BasePlugin
 
 P = ParamSpec("P")  # Params for generic function signatures
-R = TypeVar("R") # Return type for generic function signatures
+R = TypeVar("R")  # Return type for generic function signatures
 
 LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
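The P/R type variables exist so the loss wrapper can be typed generically. A sketch of a decorator with that shape, assuming a pass-through body (the real compute_loss_wrapper body is not shown in this diff):

    from functools import wraps
    from typing import Callable, ParamSpec, TypeVar

    P = ParamSpec("P")  # Params for generic function signatures
    R = TypeVar("R")  # Return type for generic function signatures

    def compute_loss_wrapper(func: Callable[P, R]) -> Callable[P, R]:
        @wraps(func)  # preserve name/docstring of the wrapped compute_loss
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Illustrative pass-through; the actual wrapper presumably adds
            # compression-aware loss handling around this call.
            return func(*args, **kwargs)

        return wrapper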
@@ -39,11 +39,17 @@ class SFTCallbackHandler(TrainerCallback):
         """
         super().__init__()
         self.trainer = trainer
-        self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        self.recipe = (
+            Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        )
         self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
 
     def on_train_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of training. Initializes the compression session.
@@ -63,7 +69,11 @@ class SFTCallbackHandler(TrainerCallback):
         )
 
     def on_step_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of a training step. Triggers batch_start callback.
@@ -72,7 +82,11 @@ class SFTCallbackHandler(TrainerCallback):
         session_callbacks.batch_start()
 
     def on_step_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of a training step. Triggers optimizer and batch_end callbacks.
@@ -83,7 +97,11 @@ class SFTCallbackHandler(TrainerCallback):
         session_callbacks.batch_end()
 
     def on_train_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of training. Finalizes the compression session.
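Net effect of this file: HF Trainer lifecycle hooks are mapped onto llmcompressor session callbacks (batch_start on step begin; optimizer and batch_end callbacks on step end). A minimal wiring sketch, assuming an already-constructed Trainer and recipe; only SFTCallbackHandler's (trainer, recipe) signature comes from this diff:

    from transformers import Trainer

    def attach_sparse_finetuning(trainer: Trainer, recipe) -> None:
        # The handler validates the recipe and wraps trainer.compute_loss
        # in its __init__ (see the constructor hunk above).
        handler = SFTCallbackHandler(trainer=trainer, recipe=recipe)
        # add_callback is the standard transformers API for registering
        # a TrainerCallback after construction.
        trainer.add_callback(handler)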

File 3 of 4

@@ -2,19 +2,18 @@
 LLMCompressor and Sparse Finetuning config models.
 """
-from pydantic import BaseModel, Field, ConfigDict
 from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
 
 class SFTArgs(BaseModel):
     """Sparse Finetuning config for LLMCompressor."""
 
     # Typing for recipe is set to Any due to:
     # https://github.com/vllm-project/llm-compressor/issues/1319
-    recipe: Annotated[
-        Any,
-        Field(description="Recipe config.")
-    ]
+    recipe: Annotated[Any, Field(description="Recipe config.")]
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
@@ -25,10 +24,7 @@ class SFTArgs(BaseModel):
 class LLMCompressorArgs(BaseModel):
     """LLMCompressor configuration BaseModel."""
 
-    llmcompressor: Annotated[
-        SFTArgs,
-        Field(description="SFT llmcompressor args")
-    ]
+    llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]
 
     model_config = ConfigDict(
         validate_assignment=True,
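These models are the typed entry point for the llmcompressor section of an axolotl config. A quick validation sketch using the models defined above; the input dict is illustrative, only the model and field names come from this file:

    # Hypothetical raw config as it might arrive from parsed YAML.
    raw_cfg = {
        "llmcompressor": {
            "recipe": {"ConstantPruningModifier": {"start": 0}},
        }
    }

    args = LLMCompressorArgs.model_validate(raw_cfg)  # pydantic v2 entry point
    print(args.llmcompressor.recipe)  # recipe stays an untyped Any (issue #1319)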

File 4 of 4

@@ -141,7 +141,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         hasattr(model_config, "quantization_config")
         and model_config.quantization_config
     )
-
+
     # Detect compressed-tensors config
     is_compressed_tensors_config = (
         quant_config_exists
@@ -156,7 +156,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         )
         # Skip further quant checks for compressed-tensors
         return
-
+
     quant_config_method_is_gptq = (
         quant_config_exists
         and "quant_method" in model_config.quantization_config