Address review comments from @markurtz

This commit is contained in:
Rahul Tuli
2025-04-04 17:59:41 +00:00
committed by Wing Lian
parent 8855bb115f
commit 83a88b745f
4 changed files with 19 additions and 12 deletions

View File

@@ -1,7 +1,7 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4 base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins: plugins:
- axolotl.integrations.llmcompressor_sft.SFTPlugin - axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false

View File

@@ -149,9 +149,6 @@ extras_require = {
"vllm": [ "vllm": [
"vllm==0.7.2", "vllm==0.7.2",
], ],
"llmcompressor": [
"llm-compressor==0.5.0",
],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -19,10 +19,10 @@ from ..base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft") LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class SFTCallbackHandler(TrainerCallback): class LLMCompressorCallbackHandler(TrainerCallback):
""" """
Trainer callback for Sparse Finetuning. Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps, Maintains sparsity patterns during training by applying masks after optimization steps,
@@ -111,7 +111,7 @@ class SFTCallbackHandler(TrainerCallback):
session.finalize() session.finalize()
class SFTPlugin(BasePlugin): class LLMCompressorPlugin(BasePlugin):
""" """
Sparse Finetuning plugin for Axolotl integration. Sparse Finetuning plugin for Axolotl integration.
""" """
@@ -123,7 +123,7 @@ class SFTPlugin(BasePlugin):
Returns: Returns:
str: Dotted path to the LLMCompressorArgs class. str: Dotted path to the LLMCompressorArgs class.
""" """
return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs" return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
""" """
@@ -137,7 +137,7 @@ class SFTPlugin(BasePlugin):
list: List containing the configured callback instances. list: List containing the configured callback instances.
""" """
LOG.info("Adding Sparse Finetuning callback to the trainer") LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = SFTCallbackHandler( callback = LLMCompressorCallbackHandler(
trainer=trainer, trainer=trainer,
recipe=cfg.llmcompressor.recipe, recipe=cfg.llmcompressor.recipe,
) )

View File

@@ -8,12 +8,17 @@ from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
class SFTArgs(BaseModel): class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor.""" """Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to: # Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319 # https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[Any, Field(description="The recipe containing the compression algorithms and hyperparameters to apply.")] recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
@@ -24,7 +29,12 @@ class SFTArgs(BaseModel):
class LLMCompressorArgs(BaseModel): class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel.""" """LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[SFTArgs, Field(description="Arguments enabling compression pathways through the LLM Compressor plugins")] llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]
model_config = ConfigDict( model_config = ConfigDict(
validate_assignment=True, validate_assignment=True,