From 8beb2f27adf64c3fb63fe14b3ee9b97b9efe5058 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 2 Apr 2025 22:35:36 +0000
Subject: [PATCH] pre commit hooks

---
 examples/llama-3/sft.yaml                      |  6 +--
 .../llmcompressor_sft/__init__.py              | 42 +++++++++++++------
 .../integrations/llmcompressor_sft/args.py     | 14 +++----
 src/axolotl/utils/models.py                    |  4 +-
 4 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/examples/llama-3/sft.yaml b/examples/llama-3/sft.yaml
index 0ec32626f..078c4873f 100644
--- a/examples/llama-3/sft.yaml
+++ b/examples/llama-3/sft.yaml
@@ -1,4 +1,4 @@
-base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed" 
+base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
 # TODO: change to
 # base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
 
@@ -68,11 +68,11 @@ llmcompressor:
     ConstantPruningModifier:
       targets: [
         're:.*q_proj.weight',
-        're:.*k_proj.weight', 
+        're:.*k_proj.weight',
         're:.*v_proj.weight',
         're:.*o_proj.weight',
         're:.*gate_proj.weight',
         're:.*up_proj.weight',
         're:.*down_proj.weight',
       ]
-      start: 0
\ No newline at end of file
+      start: 0
diff --git a/src/axolotl/integrations/llmcompressor_sft/__init__.py b/src/axolotl/integrations/llmcompressor_sft/__init__.py
index b99daef47..585756185 100644
--- a/src/axolotl/integrations/llmcompressor_sft/__init__.py
+++ b/src/axolotl/integrations/llmcompressor_sft/__init__.py
@@ -5,19 +5,19 @@ by maintaining masks for zero weights during training.
 """
 
 import logging
 from functools import wraps
-from typing import Callable, TypeVar, ParamSpec, Any
+from typing import Any, Callable, ParamSpec, TypeVar
 
-from transformers.trainer import Trainer
-from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
-from transformers.training_args import TrainingArguments
-
-from ..base import BasePlugin
 from llmcompressor import active_session
 from llmcompressor.core import callbacks as session_callbacks
 from llmcompressor.recipe import Recipe
+from transformers.trainer import Trainer
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+
+from ..base import BasePlugin
 
 P = ParamSpec("P")  # Params for generic function signatures
-R = TypeVar("R") # Return type for generic function signatures
+R = TypeVar("R")  # Return type for generic function signatures
 
 LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
@@ -39,11 +39,17 @@ class SFTCallbackHandler(TrainerCallback):
         """
         super().__init__()
         self.trainer = trainer
-        self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        self.recipe = (
+            Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
+        )
         self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
 
     def on_train_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of training. Initializes the compression session.
@@ -63,7 +69,11 @@ class SFTCallbackHandler(TrainerCallback):
         )
 
     def on_step_begin(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the beginning of a training step. Triggers batch_start callback.
@@ -72,7 +82,11 @@ class SFTCallbackHandler(TrainerCallback):
         session_callbacks.batch_start()
 
     def on_step_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of a training step. Triggers optimizer and batch_end callbacks.
@@ -83,7 +97,11 @@ class SFTCallbackHandler(TrainerCallback):
         session_callbacks.batch_end()
 
     def on_train_end(
-        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
     ) -> None:
         """
         Called at the end of training. Finalizes the compression session.
diff --git a/src/axolotl/integrations/llmcompressor_sft/args.py b/src/axolotl/integrations/llmcompressor_sft/args.py
index 24106bc32..fe7c1555d 100644
--- a/src/axolotl/integrations/llmcompressor_sft/args.py
+++ b/src/axolotl/integrations/llmcompressor_sft/args.py
@@ -2,19 +2,18 @@
 LLMCompressor and Sparse Finetuning config models.
 """
 
-from pydantic import BaseModel, Field, ConfigDict
 from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
+
 class SFTArgs(BaseModel):
     """Sparse Finetuning config for LLMCompressor."""
 
     # Typing for recipe is set to Any due to:
     # https://github.com/vllm-project/llm-compressor/issues/1319
-    recipe: Annotated[
-        Any,
-        Field(description="Recipe config.")
-    ]
+    recipe: Annotated[Any, Field(description="Recipe config.")]
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
@@ -25,10 +24,7 @@ class SFTArgs(BaseModel):
 class LLMCompressorArgs(BaseModel):
     """LLMCompressor configuration BaseModel."""
 
-    llmcompressor: Annotated[
-        SFTArgs,
-        Field(description="SFT llmcompressor args")
-    ]
+    llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]
 
     model_config = ConfigDict(
         validate_assignment=True,
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index cb9b316aa..61823f369 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -141,7 +141,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         hasattr(model_config, "quantization_config")
         and model_config.quantization_config
     )
-    
+
     # Detect compressed-tensors config
     is_compressed_tensors_config = (
         quant_config_exists
@@ -156,7 +156,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         )
         # Skip further quant checks for compressed-tensors
         return
-    
+
     quant_config_method_is_gptq = (
         quant_config_exists
        and "quant_method" in model_config.quantization_config
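Aside on the wrapped loss function: `SFTCallbackHandler.__init__` above rebinds `trainer.compute_loss` through `compute_loss_wrapper`, whose definition lies outside the hunks shown here. Below is a minimal sketch of what such a wrapper can look like, assuming llmcompressor's `loss_calculated` lifecycle callback and the module's `P`/`R` type variables; the name matches the diff, but the body is illustrative, not this patch's actual implementation.

from functools import wraps
from typing import Callable, ParamSpec, TypeVar

from llmcompressor.core import callbacks as session_callbacks

P = ParamSpec("P")
R = TypeVar("R")


def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
    """Report each computed loss to the active compression session."""

    @wraps(compute_loss_func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        loss = compute_loss_func(*args, **kwargs)
        # loss_calculated is llmcompressor's lifecycle event for losses;
        # reporting it from inside compute_loss is an assumption here.
        session_callbacks.loss_calculated(loss=loss)
        return loss

    return wrapper

Note that `Trainer.compute_loss` can also be called with `return_outputs=True` and then returns a `(loss, outputs)` tuple, so a production implementation would need to unpack that case before reporting the loss.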