From 83a88b745f050173465ef06f9ece0503a3686e10 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 4 Apr 2025 17:59:41 +0000 Subject: [PATCH] Address review comments from @markurtz --- examples/llama-3/sft.yaml | 2 +- setup.py | 3 --- .../__init__.py | 10 +++++----- .../args.py | 16 +++++++++++++--- 4 files changed, 19 insertions(+), 12 deletions(-) rename src/axolotl/integrations/{llmcompressor_sft => llm_compressor}/__init__.py (94%) rename src/axolotl/integrations/{llmcompressor_sft => llm_compressor}/args.py (60%) diff --git a/examples/llama-3/sft.yaml b/examples/llama-3/sft.yaml index e688e1165..15c1f3e7f 100644 --- a/examples/llama-3/sft.yaml +++ b/examples/llama-3/sft.yaml @@ -1,7 +1,7 @@ base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4 plugins: - - axolotl.integrations.llmcompressor_sft.SFTPlugin + - axolotl.integrations.llm_compressor.LLMCompressorPlugin load_in_8bit: false load_in_4bit: false diff --git a/setup.py b/setup.py index 291ffeaa3..5b66a2ea7 100644 --- a/setup.py +++ b/setup.py @@ -149,9 +149,6 @@ extras_require = { "vllm": [ "vllm==0.7.2", ], - "llmcompressor": [ - "llm-compressor==0.5.0", - ], } install_requires, dependency_links, extras_require_build = parse_requirements( diff --git a/src/axolotl/integrations/llmcompressor_sft/__init__.py b/src/axolotl/integrations/llm_compressor/__init__.py similarity index 94% rename from src/axolotl/integrations/llmcompressor_sft/__init__.py rename to src/axolotl/integrations/llm_compressor/__init__.py index 585756185..4ed55ead6 100644 --- a/src/axolotl/integrations/llmcompressor_sft/__init__.py +++ b/src/axolotl/integrations/llm_compressor/__init__.py @@ -19,10 +19,10 @@ from ..base import BasePlugin P = ParamSpec("P") # Params for generic function signatures R = TypeVar("R") # Return type for generic function signatures -LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft") +LOG = logging.getLogger("axolotl.integrations.llm_compressor") -class SFTCallbackHandler(TrainerCallback): +class LLMCompressorCallbackHandler(TrainerCallback): """ Trainer callback for Sparse Finetuning. Maintains sparsity patterns during training by applying masks after optimization steps, @@ -111,7 +111,7 @@ class SFTCallbackHandler(TrainerCallback): session.finalize() -class SFTPlugin(BasePlugin): +class LLMCompressorPlugin(BasePlugin): """ Sparse Finetuning plugin for Axolotl integration. """ @@ -123,7 +123,7 @@ class SFTPlugin(BasePlugin): Returns: str: Dotted path to the LLMCompressorArgs class. """ - return "axolotl.integrations.llmcompressor_sft.args.LLMCompressorArgs" + return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs" def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: """ @@ -137,7 +137,7 @@ class SFTPlugin(BasePlugin): list: List containing the configured callback instances. """ LOG.info("Adding Sparse Finetuning callback to the trainer") - callback = SFTCallbackHandler( + callback = LLMCompressorCallbackHandler( trainer=trainer, recipe=cfg.llmcompressor.recipe, ) diff --git a/src/axolotl/integrations/llmcompressor_sft/args.py b/src/axolotl/integrations/llm_compressor/args.py similarity index 60% rename from src/axolotl/integrations/llmcompressor_sft/args.py rename to src/axolotl/integrations/llm_compressor/args.py index 0faf5b91d..3b1ab34ba 100644 --- a/src/axolotl/integrations/llmcompressor_sft/args.py +++ b/src/axolotl/integrations/llm_compressor/args.py @@ -8,12 +8,17 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -class SFTArgs(BaseModel): +class CompressionArgs(BaseModel): """Sparse Finetuning config for LLMCompressor.""" # Typing for recipe is set to Any due to: # https://github.com/vllm-project/llm-compressor/issues/1319 - recipe: Annotated[Any, Field(description="The recipe containing the compression algorithms and hyperparameters to apply.")] + recipe: Annotated[ + Any, + Field( + description="The recipe containing the compression algorithms and hyperparameters to apply." + ), + ] model_config = ConfigDict( arbitrary_types_allowed=True, @@ -24,7 +29,12 @@ class SFTArgs(BaseModel): class LLMCompressorArgs(BaseModel): """LLMCompressor configuration BaseModel.""" - llmcompressor: Annotated[SFTArgs, Field(description="Arguments enabling compression pathways through the LLM Compressor plugins")] + llmcompressor: Annotated[ + CompressionArgs, + Field( + description="Arguments enabling compression pathways through the LLM Compressor plugins" + ), + ] model_config = ConfigDict( validate_assignment=True,