From 12cd09e6f58d128b83532d947397ce6afd8cf859 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 17 Apr 2025 17:19:59 -0400
Subject: [PATCH] Rebase and updates!

---
 examples/llama-3/sparse-finetuning.yaml           |  1 +
 setup.py                                          |  3 ++-
 src/axolotl/integrations/llm_compressor/args.py   | 10 +++++++---
 src/axolotl/integrations/llm_compressor/plugin.py | 14 ++++++++------
 src/axolotl/integrations/llm_compressor/utils.py  | 15 +++++++++++++++
 src/axolotl/train.py                              | 14 ++++++++++++++
 6 files changed, 47 insertions(+), 10 deletions(-)
 create mode 100644 src/axolotl/integrations/llm_compressor/utils.py

diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
index 15c1f3e7f..1bbb88028 100644
--- a/examples/llama-3/sparse-finetuning.yaml
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -74,3 +74,4 @@ llmcompressor:
             're:.*down_proj.weight',
           ]
           start: 0
+  save_compressed: true
diff --git a/setup.py b/setup.py
index 88a8968ed..e2433d394 100644
--- a/setup.py
+++ b/setup.py
@@ -149,8 +149,9 @@ extras_require = {
     "vllm": [
         "vllm==0.7.2",
     ],
+    # PENDING: https://github.com/vllm-project/llm-compressor/pull/1352
     "llmcompressor": [
-        "llmcompressor~=0.5.0",
+        "llmcompressor==0.5.1",
     ],
 }
 
diff --git a/src/axolotl/integrations/llm_compressor/args.py b/src/axolotl/integrations/llm_compressor/args.py
index a6e115dfc..5ab62325f 100644
--- a/src/axolotl/integrations/llm_compressor/args.py
+++ b/src/axolotl/integrations/llm_compressor/args.py
@@ -20,9 +20,13 @@ class CompressionArgs(BaseModel):
         ),
     ]
 
-    model_config = ConfigDict(
-        validate_assignment=True,
-    )
+    save_compressed: Annotated[
+        bool,
+        Field(
+            default=False,
+            description="Whether to save the compressed model after training.",
+        ),
+    ]
 
 
 class LLMCompressorArgs(BaseModel):
diff --git a/src/axolotl/integrations/llm_compressor/plugin.py b/src/axolotl/integrations/llm_compressor/plugin.py
index d4797b7c2..45895e42d 100644
--- a/src/axolotl/integrations/llm_compressor/plugin.py
+++ b/src/axolotl/integrations/llm_compressor/plugin.py
@@ -7,7 +7,7 @@ import logging
 from functools import wraps
 from typing import Any, Callable, ParamSpec, TypeVar
 
-from llmcompressor import active_session
+from llmcompressor import active_session, create_session
 from llmcompressor.core import callbacks as session_callbacks
 from llmcompressor.recipe import Recipe
 from transformers.trainer import Trainer
@@ -43,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
             Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
         )
         self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
+        create_session()
 
     def on_train_begin(
         self,
@@ -60,13 +61,14 @@ class LLMCompressorCallbackHandler(TrainerCallback):
             control (TrainerControl): Trainer control.
         """
         super().on_train_begin(args, state, control, **kwargs)
-        session = active_session()
-        session.initialize(
+        self.trainer.accelerator.wait_for_everyone()
+        active_session().initialize(
             model=self.trainer.model,
             optimizer=self.trainer.optimizer,
             start=state.epoch,
             recipe=self.recipe,
         )
+        self.trainer.accelerator.wait_for_everyone()
 
     def on_step_begin(
         self,
@@ -107,8 +109,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
         Called at the end of training. Finalizes the compression session.
""" super().on_train_end(args, state, control, **kwargs) - session = active_session() - session.finalize() + active_session().finalize() class LLMCompressorPlugin(BasePlugin): @@ -158,7 +159,8 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]: @wraps(compute_loss_func) def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R: loss = compute_loss_func(*args, **kwargs) - session_callbacks.loss_calculated(loss=loss) + if active_session().lifecycle.initialized_: + session_callbacks.loss_calculated(loss=loss) return loss return compute_and_notify diff --git a/src/axolotl/integrations/llm_compressor/utils.py b/src/axolotl/integrations/llm_compressor/utils.py new file mode 100644 index 000000000..945c0f3ac --- /dev/null +++ b/src/axolotl/integrations/llm_compressor/utils.py @@ -0,0 +1,15 @@ +from transformers import Trainer + +def save_compressed_model( + model, output_dir, trainer: Trainer, safe_serialization: bool, save_compressed:bool +): + from llmcompressor.transformers.sparsification.compressed_tensors_utils import modify_save_pretrained + trainer.accelerator.wait_for_everyone() + if trainer.accelerator.is_main_process: + modify_save_pretrained(model) + model.save_pretrained( + output_dir, + safe_serialization=safe_serialization, + save_compressed=save_compressed, + skip_sparsity_compression_stats=not save_compressed, + ) \ No newline at end of file diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 7896239de..4c6d77144 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -288,6 +288,19 @@ def save_trained_model( os.remove(os.path.join(cfg.output_dir, "model.safetensors")) except FileNotFoundError: pass + elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor: + from axolotl.integrations.llm_compressor.utils import ( + save_compressed_model, + ) + + save_compressed_model( + model=model, + output_dir=cfg.output_dir, + trainer=trainer, + safe_serialization=safe_serialization, + save_compressed=cfg.llmcompressor.save_compressed, + ) + elif cfg.local_rank == 0: if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) @@ -296,6 +309,7 @@ def save_trained_model( trainer.model.save_pretrained( cfg.output_dir, safe_serialization=safe_serialization ) + model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)