Address review comments from @markurtz: rename the llmcompressor_sft integration to llm_compressor

This commit is contained in:
Rahul Tuli
2025-04-04 17:59:41 +00:00
committed by Wing Lian
parent 8855bb115f
commit 83a88b745f
4 changed files with 19 additions and 12 deletions

View File

@@ -1,7 +1,7 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llmcompressor_sft.SFTPlugin
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false

View File

@@ -149,9 +149,6 @@ extras_require = {
"vllm": [
"vllm==0.7.2",
],
"llmcompressor": [
"llm-compressor==0.5.0",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -19,10 +19,10 @@ from ..base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class SFTCallbackHandler(TrainerCallback):
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
@@ -111,7 +111,7 @@ class SFTCallbackHandler(TrainerCallback):
session.finalize()
class SFTPlugin(BasePlugin):
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
@@ -123,7 +123,7 @@ class SFTPlugin(BasePlugin):
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
@@ -137,7 +137,7 @@ class SFTPlugin(BasePlugin):
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = SFTCallbackHandler(
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)

View File

@@ -8,12 +8,17 @@ from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
class SFTArgs(BaseModel):
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[Any, Field(description="The recipe containing the compression algorithms and hyperparameters to apply.")]
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
model_config = ConfigDict(
arbitrary_types_allowed=True,
@@ -24,7 +29,12 @@ class SFTArgs(BaseModel):
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[SFTArgs, Field(description="Arguments enabling compression pathways through the LLM Compressor plugins")]
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]
model_config = ConfigDict(
validate_assignment=True,