pre-commit hooks
@@ -1,4 +1,4 @@
 base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
 # TODO: change to
 # base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4

@@ -68,11 +68,11 @@ llmcompressor:
     ConstantPruningModifier:
       targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0

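For context on the 're:' targets above: in llmcompressor recipes, a target beginning with 're:' is treated as a regular expression and matched against parameter names, so 're:.*q_proj.weight' picks up every query-projection weight in the model. A minimal, self-contained sketch of that matching idea (illustrative only, not llmcompressor's actual target resolver; the parameter names are hypothetical Llama-style names):

import re

# Hypothetical parameter names as they might appear in a Llama-style model.
param_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.bias",
    "model.layers.0.mlp.up_proj.weight",
    "model.embed_tokens.weight",
]

target = "re:.*q_proj.weight"
pattern = target.removeprefix("re:")  # strip the recipe's regex marker

matched = [name for name in param_names if re.match(pattern, name)]
print(matched)  # only the q_proj weight matches
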
@@ -5,19 +5,19 @@ by maintaining masks for zero weights during training.

 import logging
 from functools import wraps
-from typing import Callable, TypeVar, ParamSpec, Any
+from typing import Any, Callable, ParamSpec, TypeVar

-from transformers.trainer import Trainer
-from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
-from transformers.training_args import TrainingArguments
-
-from ..base import BasePlugin
 from llmcompressor import active_session
 from llmcompressor.core import callbacks as session_callbacks
 from llmcompressor.recipe import Recipe
+from transformers.trainer import Trainer
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+
+from ..base import BasePlugin

 P = ParamSpec("P")  # Params for generic function signatures
 R = TypeVar("R")  # Return type for generic function signatures

 LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")

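The P/R aliases exist so that wrappers such as compute_loss_wrapper (used later in this file) can be typed generically without erasing the wrapped function's signature. A minimal sketch of that decorator pattern; the logging body is a placeholder, not the plugin's actual loss hook:

import logging
from functools import wraps
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

LOG = logging.getLogger(__name__)


def compute_loss_wrapper(func: Callable[P, R]) -> Callable[P, R]:
    """Wrap a compute_loss-style function while preserving its signature."""

    @wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        loss = func(*args, **kwargs)
        LOG.debug("compute_loss returned: %r", loss)  # placeholder side effect
        return loss

    return wrapped
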
@@ -39,11 +39,17 @@ class SFTCallbackHandler(TrainerCallback):
         """
         super().__init__()
         self.trainer = trainer
-        self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        self.recipe = (
+            Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        )
         self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)

     def on_train_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of training. Initializes the compression session.

@@ -63,7 +69,11 @@ class SFTCallbackHandler(TrainerCallback):
         )

     def on_step_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of a training step. Triggers batch_start callback.

@@ -72,7 +82,11 @@ class SFTCallbackHandler(TrainerCallback):
         session_callbacks.batch_start()

     def on_step_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of a training step. Triggers optimizer and batch_end callbacks.

@@ -83,7 +97,11 @@ class SFTCallbackHandler(TrainerCallback):
         session_callbacks.batch_end()

     def on_train_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of training. Finalizes the compression session.

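Taken together, the four hooks map the HF Trainer lifecycle onto the compression session: on_train_begin initializes it, on_step_begin fires batch_start, on_step_end fires the optimizer and batch_end callbacks, and on_train_end finalizes. A toy, self-contained driver showing that ordering; the printing stand-ins are illustrative, the real calls go through llmcompressor's session:

class FakeSessionCallbacks:
    """Stand-in for llmcompressor.core.callbacks, for illustration only."""

    @staticmethod
    def batch_start() -> None:
        print("session: batch_start")

    @staticmethod
    def batch_end() -> None:
        print("session: batch_end")


def training_loop(num_steps: int) -> None:
    print("session: initialize (on_train_begin)")
    for step in range(num_steps):
        FakeSessionCallbacks.batch_start()   # on_step_begin
        print(f"  step {step}: forward/backward/optimizer")
        FakeSessionCallbacks.batch_end()     # on_step_end
    print("session: finalize (on_train_end)")


training_loop(2)
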
@@ -2,19 +2,18 @@
 LLMCompressor and Sparse Finetuning config models.
 """

-from pydantic import BaseModel, Field, ConfigDict
 from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated


 class SFTArgs(BaseModel):
     """Sparse Finetuning config for LLMCompressor."""

     # Typing for recipe is set to Any due to:
     # https://github.com/vllm-project/llm-compressor/issues/1319
-    recipe: Annotated[
-        Any,
-        Field(description="Recipe config.")
-    ]
+    recipe: Annotated[Any, Field(description="Recipe config.")]

     model_config = ConfigDict(
         arbitrary_types_allowed=True,

@@ -25,10 +24,7 @@ class SFTArgs(BaseModel):
 class LLMCompressorArgs(BaseModel):
     """LLMCompressor configuration BaseModel."""

-    llmcompressor: Annotated[
-        SFTArgs,
-        Field(description="SFT llmcompressor args")
-    ]
+    llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]

     model_config = ConfigDict(
         validate_assignment=True,

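Collapsing the Annotated blocks onto one line is purely cosmetic; validation behaves the same. For illustration, a self-contained sketch of the two models as they read after this change, round-tripped through model_validate (assumes pydantic v2; the recipe payload is an arbitrary placeholder since the field is typed Any):

from typing import Any

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated


class SFTArgs(BaseModel):
    """Sparse Finetuning config for LLMCompressor."""

    recipe: Annotated[Any, Field(description="Recipe config.")]

    model_config = ConfigDict(arbitrary_types_allowed=True)


class LLMCompressorArgs(BaseModel):
    """LLMCompressor configuration BaseModel."""

    llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]

    model_config = ConfigDict(validate_assignment=True)


args = LLMCompressorArgs.model_validate(
    {"llmcompressor": {"recipe": {"finetuning_stage": {}}}}
)
print(args.llmcompressor.recipe)  # the raw recipe dict, typed as Any
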
@@ -141,7 +141,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         hasattr(model_config, "quantization_config")
         and model_config.quantization_config
     )

     # Detect compressed-tensors config
     is_compressed_tensors_config = (
         quant_config_exists

@@ -156,7 +156,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         )
         # Skip further quant checks for compressed-tensors
         return

     quant_config_method_is_gptq = (
         quant_config_exists
         and "quant_method" in model_config.quantization_config

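The "quant_method" lookup above is what gates the GPTQ-specific validation. A standalone sketch of that detection pattern, using SimpleNamespace in place of a real PretrainedConfig; the helper name quant_method_of is hypothetical, not part of this codebase:

from types import SimpleNamespace


def quant_method_of(model_config) -> str | None:
    """Return the quant_method from a config's quantization_config, if any."""
    quant_config_exists = (
        hasattr(model_config, "quantization_config")
        and model_config.quantization_config
    )
    if not quant_config_exists:
        return None
    return model_config.quantization_config.get("quant_method")


cfg = SimpleNamespace(quantization_config={"quant_method": "gptq", "bits": 4})
print(quant_method_of(cfg))            # -> "gptq"
print(quant_method_of(SimpleNamespace()))  # -> None
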