Address review comments from @markurtz
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
|
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.llmcompressor_sft.SFTPlugin
|
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
|
||||||
|
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: false
|
load_in_4bit: false
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -149,9 +149,6 @@ extras_require = {
|
|||||||
"vllm": [
|
"vllm": [
|
||||||
"vllm==0.7.2",
|
"vllm==0.7.2",
|
||||||
],
|
],
|
||||||
"llmcompressor": [
|
|
||||||
"llm-compressor==0.5.0",
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ from ..base import BasePlugin
|
|||||||
P = ParamSpec("P") # Params for generic function signatures
|
P = ParamSpec("P") # Params for generic function signatures
|
||||||
R = TypeVar("R") # Return type for generic function signatures
|
R = TypeVar("R") # Return type for generic function signatures
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
|
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
|
||||||
|
|
||||||
|
|
||||||
class SFTCallbackHandler(TrainerCallback):
|
class LLMCompressorCallbackHandler(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
Trainer callback for Sparse Finetuning.
|
Trainer callback for Sparse Finetuning.
|
||||||
Maintains sparsity patterns during training by applying masks after optimization steps,
|
Maintains sparsity patterns during training by applying masks after optimization steps,
|
||||||
@@ -111,7 +111,7 @@ class SFTCallbackHandler(TrainerCallback):
|
|||||||
session.finalize()
|
session.finalize()
|
||||||
|
|
||||||
|
|
||||||
class SFTPlugin(BasePlugin):
|
class LLMCompressorPlugin(BasePlugin):
|
||||||
"""
|
"""
|
||||||
Sparse Finetuning plugin for Axolotl integration.
|
Sparse Finetuning plugin for Axolotl integration.
|
||||||
"""
|
"""
|
||||||
@@ -123,7 +123,7 @@ class SFTPlugin(BasePlugin):
|
|||||||
Returns:
|
Returns:
|
||||||
str: Dotted path to the LLMCompressorArgs class.
|
str: Dotted path to the LLMCompressorArgs class.
|
||||||
"""
|
"""
|
||||||
return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs"
|
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
|
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
|
||||||
"""
|
"""
|
||||||
@@ -137,7 +137,7 @@ class SFTPlugin(BasePlugin):
|
|||||||
list: List containing the configured callback instances.
|
list: List containing the configured callback instances.
|
||||||
"""
|
"""
|
||||||
LOG.info("Adding Sparse Finetuning callback to the trainer")
|
LOG.info("Adding Sparse Finetuning callback to the trainer")
|
||||||
callback = SFTCallbackHandler(
|
callback = LLMCompressorCallbackHandler(
|
||||||
trainer=trainer,
|
trainer=trainer,
|
||||||
recipe=cfg.llmcompressor.recipe,
|
recipe=cfg.llmcompressor.recipe,
|
||||||
)
|
)
|
||||||
@@ -8,12 +8,17 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
class SFTArgs(BaseModel):
|
class CompressionArgs(BaseModel):
|
||||||
"""Sparse Finetuning config for LLMCompressor."""
|
"""Sparse Finetuning config for LLMCompressor."""
|
||||||
|
|
||||||
# Typing for recipe is set to Any due to:
|
# Typing for recipe is set to Any due to:
|
||||||
# https://github.com/vllm-project/llm-compressor/issues/1319
|
# https://github.com/vllm-project/llm-compressor/issues/1319
|
||||||
recipe: Annotated[Any, Field(description="The recipe containing the compression algorithms and hyperparameters to apply.")]
|
recipe: Annotated[
|
||||||
|
Any,
|
||||||
|
Field(
|
||||||
|
description="The recipe containing the compression algorithms and hyperparameters to apply."
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
arbitrary_types_allowed=True,
|
arbitrary_types_allowed=True,
|
||||||
@@ -24,7 +29,12 @@ class SFTArgs(BaseModel):
|
|||||||
class LLMCompressorArgs(BaseModel):
|
class LLMCompressorArgs(BaseModel):
|
||||||
"""LLMCompressor configuration BaseModel."""
|
"""LLMCompressor configuration BaseModel."""
|
||||||
|
|
||||||
llmcompressor: Annotated[SFTArgs, Field(description="Arguments enabling compression pathways through the LLM Compressor plugins")]
|
llmcompressor: Annotated[
|
||||||
|
CompressionArgs,
|
||||||
|
Field(
|
||||||
|
description="Arguments enabling compression pathways through the LLM Compressor plugins"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
Reference in New Issue
Block a user