Address review comments from @markurtz
@@ -1,7 +1,7 @@
 base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
 
 plugins:
-  - axolotl.integrations.llmcompressor_sft.SFTPlugin
+  - axolotl.integrations.llm_compressor.LLMCompressorPlugin
 
 load_in_8bit: false
 load_in_4bit: false
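The plugins entry above is a dotted import path. A minimal sketch of how such a path can be resolved into a plugin class, assuming a plain importlib-based loader (load_plugin is illustrative, not Axolotl's actual loader):

    import importlib

    def load_plugin(dotted_path: str):
        # Split "package.module.ClassName" into the module path and class name.
        module_path, _, class_name = dotted_path.rpartition(".")
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    # The renamed plugin path from the config above:
    plugin_cls = load_plugin("axolotl.integrations.llm_compressor.LLMCompressorPlugin")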
setup.py
@@ -149,9 +149,6 @@ extras_require = {
         "vllm": [
             "vllm==0.7.2",
         ],
-        "llmcompressor": [
-            "llm-compressor==0.5.0",
-        ],
     }
 
     install_requires, dependency_links, extras_require_build = parse_requirements(
@@ -19,10 +19,10 @@ from ..base import BasePlugin
 P = ParamSpec("P")  # Params for generic function signatures
 R = TypeVar("R")  # Return type for generic function signatures
 
-LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
+LOG = logging.getLogger("axolotl.integrations.llm_compressor")
 
 
-class SFTCallbackHandler(TrainerCallback):
+class LLMCompressorCallbackHandler(TrainerCallback):
     """
     Trainer callback for Sparse Finetuning.
     Maintains sparsity patterns during training by applying masks after optimization steps,
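The docstring above names the technique: keep pruned weights at zero by re-applying masks after each optimizer update. An illustrative, self-contained sketch of that idea (not the integration's actual code; SparsityMaskCallback and its internals are hypothetical):

    import torch
    from transformers import TrainerCallback

    class SparsityMaskCallback(TrainerCallback):
        """Re-applies binary sparsity masks after every optimizer step."""

        def __init__(self, model):
            # Record which weights are already zero; those positions stay frozen.
            self.model = model
            self.masks = {
                name: param.detach() != 0
                for name, param in model.named_parameters()
                if param.dim() > 1  # weight matrices only; skip biases/norms
            }

        def on_step_end(self, args, state, control, **kwargs):
            # After the optimizer update, push any weight that drifted away
            # from the recorded sparsity pattern back to exactly zero.
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    mask = self.masks.get(name)
                    if mask is not None:
                        param.masked_fill_(~mask.to(param.device), 0.0)
            return control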
@@ -111,7 +111,7 @@ class SFTCallbackHandler(TrainerCallback):
         session.finalize()
 
 
-class SFTPlugin(BasePlugin):
+class LLMCompressorPlugin(BasePlugin):
     """
     Sparse Finetuning plugin for Axolotl integration.
     """
@@ -123,7 +123,7 @@ class SFTPlugin(BasePlugin):
         Returns:
             str: Dotted path to the LLMCompressorArgs class.
         """
-        return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
+        return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
 
     def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
         """
@@ -137,7 +137,7 @@ class SFTPlugin(BasePlugin):
             list: List containing the configured callback instances.
         """
         LOG.info("Adding Sparse Finetuning callback to the trainer")
-        callback = SFTCallbackHandler(
+        callback = LLMCompressorCallbackHandler(
             trainer=trainer,
             recipe=cfg.llmcompressor.recipe,
         )
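A short usage sketch of the hook renamed in this hunk, assuming plugin, cfg, and trainer already exist; Trainer.add_callback is the standard transformers API:

    # Hand the LLM Compressor callback(s) to an existing HF Trainer.
    for callback in plugin.add_callbacks_post_trainer(cfg, trainer):
        trainer.add_callback(callback)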
@@ -8,12 +8,17 @@ from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
 
-class SFTArgs(BaseModel):
+class CompressionArgs(BaseModel):
     """Sparse Finetuning config for LLMCompressor."""
 
     # Typing for recipe is set to Any due to:
     # https://github.com/vllm-project/llm-compressor/issues/1319
-    recipe: Annotated[Any, Field(description="The recipe containing the compression algorithms and hyperparameters to apply.")]
+    recipe: Annotated[
+        Any,
+        Field(
+            description="The recipe containing the compression algorithms and hyperparameters to apply."
+        ),
+    ]
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
@@ -24,7 +29,12 @@ class SFTArgs(BaseModel):
 class LLMCompressorArgs(BaseModel):
     """LLMCompressor configuration BaseModel."""
 
-    llmcompressor: Annotated[SFTArgs, Field(description="Arguments enabling compression pathways through the LLM Compressor plugins")]
+    llmcompressor: Annotated[
+        CompressionArgs,
+        Field(
+            description="Arguments enabling compression pathways through the LLM Compressor plugins"
+        ),
+    ]
 
     model_config = ConfigDict(
         validate_assignment=True,
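A hedged usage sketch of the renamed models, assuming the post-rename module layout shown in this commit; model_validate is standard Pydantic v2 API, and the recipe path is a made-up placeholder:

    from axolotl.integrations.llm_compressor.args import LLMCompressorArgs

    # Validate a config dict into the nested models defined above.
    args = LLMCompressorArgs.model_validate(
        {"llmcompressor": {"recipe": "recipes/sparse_finetune.yaml"}}
    )
    print(type(args.llmcompressor).__name__)  # CompressionArgs
    print(args.llmcompressor.recipe)          # recipes/sparse_finetune.yaml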