Rebase onto main and add `save_compressed` support to the llmcompressor integration (pin llmcompressor==0.5.1, create session in callback handler, save compressed checkpoints).

This commit is contained in:
Rahul Tuli
2025-04-17 17:19:59 -04:00
parent 45b7293793
commit ff4904c8c4
6 changed files with 47 additions and 10 deletions

View File

@@ -74,3 +74,4 @@ llmcompressor:
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -149,8 +149,9 @@ extras_require = {
"vllm": [
"vllm==0.7.2",
],
# PENDING: https://github.com/vllm-project/llm-compressor/pull/1352
"llmcompressor": [
"llmcompressor~=0.5.0",
"llmcompressor==0.5.1",
],
}

View File

@@ -20,9 +20,13 @@ class CompressionArgs(BaseModel):
),
]
model_config = ConfigDict(
validate_assignment=True,
)
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):

View File

@@ -7,7 +7,7 @@ import logging
from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar
from llmcompressor import active_session
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from transformers.trainer import Trainer
@@ -43,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
@@ -60,13 +61,14 @@ class LLMCompressorCallbackHandler(TrainerCallback):
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
session = active_session()
session.initialize(
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
@@ -107,8 +109,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
session = active_session()
session.finalize()
active_session().finalize()
class LLMCompressorPlugin(BasePlugin):
@@ -158,7 +159,8 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
@wraps(compute_loss_func)
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(*args, **kwargs)
session_callbacks.loss_calculated(loss=loss)
if active_session().lifecycle.initialized_:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -0,0 +1,15 @@
from transformers import Trainer
def save_compressed_model(
    model, output_dir, trainer: Trainer, safe_serialization: bool, save_compressed: bool
):
    """Save a model trained with llmcompressor, optionally in compressed form.

    Patches the model's ``save_pretrained`` via llmcompressor's
    ``modify_save_pretrained`` so compressed-tensors checkpoints can be
    written, then saves from the main process only.

    Args:
        model: The trained model to save.  # presumably a transformers PreTrainedModel — confirm at caller
        output_dir: Directory the checkpoint is written to.
        trainer: Trainer whose accelerator is used for cross-rank synchronization.
        safe_serialization: Whether to serialize with safetensors.
        save_compressed: If True, write the model in compressed format;
            sparsity-compression statistics are skipped when False.
    """
    # Deferred import: llmcompressor is an optional dependency, only needed
    # when this save path is actually taken.
    from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
        modify_save_pretrained,
    )

    # Barrier: make sure every rank has finished before the main process writes.
    trainer.accelerator.wait_for_everyone()

    if trainer.accelerator.is_main_process:
        # Swap in llmcompressor-aware save_pretrained on this model instance.
        modify_save_pretrained(model)
        model.save_pretrained(
            output_dir,
            safe_serialization=safe_serialization,
            save_compressed=save_compressed,
            # Computing sparsity stats only makes sense when compressing.
            skip_sparsity_compression_stats=not save_compressed,
        )

View File

@@ -271,6 +271,19 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -279,6 +292,7 @@ def save_trained_model(
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)