pre commit hooks

This commit is contained in:
Rahul Tuli
2025-04-02 22:35:36 +00:00
committed by Wing Lian
parent 56ba66b60f
commit 8beb2f27ad
4 changed files with 40 additions and 26 deletions

View File

@@ -5,19 +5,19 @@ by maintaining masks for zero weights during training.
import logging import logging
from functools import wraps from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any from typing import Any, Callable, ParamSpec, TypeVar
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
from transformers.training_args import TrainingArguments
from ..base import BasePlugin
from llmcompressor import active_session from llmcompressor import active_session
from llmcompressor.core import callbacks as session_callbacks from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe from llmcompressor.recipe import Recipe
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from ..base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft") LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
@@ -39,11 +39,17 @@ class SFTCallbackHandler(TrainerCallback):
""" """
super().__init__() super().__init__()
self.trainer = trainer self.trainer = trainer
self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
def on_train_begin( def on_train_begin(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None: ) -> None:
""" """
Called at the beginning of training. Initializes the compression session. Called at the beginning of training. Initializes the compression session.
@@ -63,7 +69,11 @@ class SFTCallbackHandler(TrainerCallback):
) )
def on_step_begin( def on_step_begin(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None: ) -> None:
""" """
Called at the beginning of a training step. Triggers batch_start callback. Called at the beginning of a training step. Triggers batch_start callback.
@@ -72,7 +82,11 @@ class SFTCallbackHandler(TrainerCallback):
session_callbacks.batch_start() session_callbacks.batch_start()
def on_step_end( def on_step_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None: ) -> None:
""" """
Called at the end of a training step. Triggers optimizer and batch_end callbacks. Called at the end of a training step. Triggers optimizer and batch_end callbacks.
@@ -83,7 +97,11 @@ class SFTCallbackHandler(TrainerCallback):
session_callbacks.batch_end() session_callbacks.batch_end()
def on_train_end( def on_train_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None: ) -> None:
""" """
Called at the end of training. Finalizes the compression session. Called at the end of training. Finalizes the compression session.

View File

@@ -2,19 +2,18 @@
LLMCompressor and Sparse Finetuning config models. LLMCompressor and Sparse Finetuning config models.
""" """
from pydantic import BaseModel, Field, ConfigDict
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
class SFTArgs(BaseModel): class SFTArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor.""" """Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to: # Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319 # https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[ recipe: Annotated[Any, Field(description="Recipe config.")]
Any,
Field(description="Recipe config.")
]
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
@@ -25,10 +24,7 @@ class SFTArgs(BaseModel):
class LLMCompressorArgs(BaseModel): class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel.""" """LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[ llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]
SFTArgs,
Field(description="SFT llmcompressor args")
]
model_config = ConfigDict( model_config = ConfigDict(
validate_assignment=True, validate_assignment=True,