pre commit hooks
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
|
base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
|
||||||
# TODO: change to
|
# TODO: change to
|
||||||
# base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
|
# base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
|
||||||
|
|
||||||
@@ -68,11 +68,11 @@ llmcompressor:
|
|||||||
ConstantPruningModifier:
|
ConstantPruningModifier:
|
||||||
targets: [
|
targets: [
|
||||||
're:.*q_proj.weight',
|
're:.*q_proj.weight',
|
||||||
're:.*k_proj.weight',
|
're:.*k_proj.weight',
|
||||||
're:.*v_proj.weight',
|
're:.*v_proj.weight',
|
||||||
're:.*o_proj.weight',
|
're:.*o_proj.weight',
|
||||||
're:.*gate_proj.weight',
|
're:.*gate_proj.weight',
|
||||||
're:.*up_proj.weight',
|
're:.*up_proj.weight',
|
||||||
're:.*down_proj.weight',
|
're:.*down_proj.weight',
|
||||||
]
|
]
|
||||||
start: 0
|
start: 0
|
||||||
|
|||||||
@@ -5,19 +5,19 @@ by maintaining masks for zero weights during training.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, TypeVar, ParamSpec, Any
|
from typing import Any, Callable, ParamSpec, TypeVar
|
||||||
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
|
|
||||||
from transformers.training_args import TrainingArguments
|
|
||||||
|
|
||||||
from ..base import BasePlugin
|
|
||||||
from llmcompressor import active_session
|
from llmcompressor import active_session
|
||||||
from llmcompressor.core import callbacks as session_callbacks
|
from llmcompressor.core import callbacks as session_callbacks
|
||||||
from llmcompressor.recipe import Recipe
|
from llmcompressor.recipe import Recipe
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
from transformers.training_args import TrainingArguments
|
||||||
|
|
||||||
|
from ..base import BasePlugin
|
||||||
|
|
||||||
P = ParamSpec("P") # Params for generic function signatures
|
P = ParamSpec("P") # Params for generic function signatures
|
||||||
R = TypeVar("R") # Return type for generic function signatures
|
R = TypeVar("R") # Return type for generic function signatures
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
|
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
|
||||||
|
|
||||||
@@ -39,11 +39,17 @@ class SFTCallbackHandler(TrainerCallback):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
self.recipe = Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
self.recipe = (
|
||||||
|
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
||||||
|
)
|
||||||
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
||||||
|
|
||||||
def on_train_begin(
|
def on_train_begin(
|
||||||
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called at the beginning of training. Initializes the compression session.
|
Called at the beginning of training. Initializes the compression session.
|
||||||
@@ -63,7 +69,11 @@ class SFTCallbackHandler(TrainerCallback):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called at the beginning of a training step. Triggers batch_start callback.
|
Called at the beginning of a training step. Triggers batch_start callback.
|
||||||
@@ -72,7 +82,11 @@ class SFTCallbackHandler(TrainerCallback):
|
|||||||
session_callbacks.batch_start()
|
session_callbacks.batch_start()
|
||||||
|
|
||||||
def on_step_end(
|
def on_step_end(
|
||||||
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
|
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
|
||||||
@@ -83,7 +97,11 @@ class SFTCallbackHandler(TrainerCallback):
|
|||||||
session_callbacks.batch_end()
|
session_callbacks.batch_end()
|
||||||
|
|
||||||
def on_train_end(
|
def on_train_end(
|
||||||
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called at the end of training. Finalizes the compression session.
|
Called at the end of training. Finalizes the compression session.
|
||||||
|
|||||||
@@ -2,19 +2,18 @@
|
|||||||
LLMCompressor and Sparse Finetuning config models.
|
LLMCompressor and Sparse Finetuning config models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
class SFTArgs(BaseModel):
|
class SFTArgs(BaseModel):
|
||||||
"""Sparse Finetuning config for LLMCompressor."""
|
"""Sparse Finetuning config for LLMCompressor."""
|
||||||
|
|
||||||
# Typing for recipe is set to Any due to:
|
# Typing for recipe is set to Any due to:
|
||||||
# https://github.com/vllm-project/llm-compressor/issues/1319
|
# https://github.com/vllm-project/llm-compressor/issues/1319
|
||||||
recipe: Annotated[
|
recipe: Annotated[Any, Field(description="Recipe config.")]
|
||||||
Any,
|
|
||||||
Field(description="Recipe config.")
|
|
||||||
]
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
arbitrary_types_allowed=True,
|
arbitrary_types_allowed=True,
|
||||||
@@ -25,10 +24,7 @@ class SFTArgs(BaseModel):
|
|||||||
class LLMCompressorArgs(BaseModel):
|
class LLMCompressorArgs(BaseModel):
|
||||||
"""LLMCompressor configuration BaseModel."""
|
"""LLMCompressor configuration BaseModel."""
|
||||||
|
|
||||||
llmcompressor: Annotated[
|
llmcompressor: Annotated[SFTArgs, Field(description="SFT llmcompressor args")]
|
||||||
SFTArgs,
|
|
||||||
Field(description="SFT llmcompressor args")
|
|
||||||
]
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
|||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
and model_config.quantization_config
|
and model_config.quantization_config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Detect compressed-tensors config
|
# Detect compressed-tensors config
|
||||||
is_compressed_tensors_config = (
|
is_compressed_tensors_config = (
|
||||||
quant_config_exists
|
quant_config_exists
|
||||||
@@ -156,7 +156,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
|||||||
)
|
)
|
||||||
# Skip further quant checks for compressed-tensors
|
# Skip further quant checks for compressed-tensors
|
||||||
return
|
return
|
||||||
|
|
||||||
quant_config_method_is_gptq = (
|
quant_config_method_is_gptq = (
|
||||||
quant_config_exists
|
quant_config_exists
|
||||||
and "quant_method" in model_config.quantization_config
|
and "quant_method" in model_config.quantization_config
|
||||||
|
|||||||
Reference in New Issue
Block a user