Rebase and updates!

This commit is contained in:
Rahul Tuli
2025-04-17 17:19:59 -04:00
parent 45b7293793
commit ff4904c8c4
6 changed files with 47 additions and 10 deletions

View File

@@ -74,3 +74,4 @@ llmcompressor:
're:.*down_proj.weight', 're:.*down_proj.weight',
] ]
start: 0 start: 0
save_compressed: true

View File

@@ -149,8 +149,9 @@ extras_require = {
"vllm": [ "vllm": [
"vllm==0.7.2", "vllm==0.7.2",
], ],
# PENDING: https://github.com/vllm-project/llm-compressor/pull/1352
"llmcompressor": [ "llmcompressor": [
"llmcompressor~=0.5.0", "llmcompressor==0.5.1",
], ],
} }

View File

@@ -20,9 +20,13 @@ class CompressionArgs(BaseModel):
), ),
] ]
model_config = ConfigDict( save_compressed: Annotated[
validate_assignment=True, bool,
) Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel): class LLMCompressorArgs(BaseModel):

View File

@@ -7,7 +7,7 @@ import logging
from functools import wraps from functools import wraps
from typing import Any, Callable, ParamSpec, TypeVar from typing import Any, Callable, ParamSpec, TypeVar
from llmcompressor import active_session from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe from llmcompressor.recipe import Recipe
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -43,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
) )
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss) self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin( def on_train_begin(
self, self,
@@ -60,13 +61,14 @@ class LLMCompressorCallbackHandler(TrainerCallback):
control (TrainerControl): Trainer control. control (TrainerControl): Trainer control.
""" """
super().on_train_begin(args, state, control, **kwargs) super().on_train_begin(args, state, control, **kwargs)
session = active_session() self.trainer.accelerator.wait_for_everyone()
session.initialize( active_session().initialize(
model=self.trainer.model, model=self.trainer.model,
optimizer=self.trainer.optimizer, optimizer=self.trainer.optimizer,
start=state.epoch, start=state.epoch,
recipe=self.recipe, recipe=self.recipe,
) )
self.trainer.accelerator.wait_for_everyone()
def on_step_begin( def on_step_begin(
self, self,
@@ -107,8 +109,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
Called at the end of training. Finalizes the compression session. Called at the end of training. Finalizes the compression session.
""" """
super().on_train_end(args, state, control, **kwargs) super().on_train_end(args, state, control, **kwargs)
session = active_session() active_session().finalize()
session.finalize()
class LLMCompressorPlugin(BasePlugin): class LLMCompressorPlugin(BasePlugin):
@@ -158,7 +159,8 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
@wraps(compute_loss_func) @wraps(compute_loss_func)
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R: def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(*args, **kwargs) loss = compute_loss_func(*args, **kwargs)
session_callbacks.loss_calculated(loss=loss) if active_session().lifecycle.initialized_:
session_callbacks.loss_calculated(loss=loss)
return loss return loss
return compute_and_notify return compute_and_notify

View File

@@ -0,0 +1,15 @@
from transformers import Trainer
def save_compressed_model(
    model,
    output_dir,
    trainer: Trainer,
    safe_serialization: bool,
    save_compressed: bool,
):
    """Save ``model`` to ``output_dir`` via llmcompressor's ``save_pretrained`` hook.

    Every process first calls ``trainer.accelerator.wait_for_everyone()``;
    after that only the process for which ``trainer.accelerator.is_main_process``
    is true patches the model and writes it out. All other processes return
    without saving.

    Args:
        model: Model to save; must expose ``save_pretrained``.
        output_dir: Destination passed straight through to ``save_pretrained``.
        trainer: Trainer whose ``accelerator`` is used for synchronisation and
            for the main-process check.
        safe_serialization: Forwarded unchanged to ``save_pretrained``.
        save_compressed: Forwarded unchanged to ``save_pretrained``; its
            negation is passed as ``skip_sparsity_compression_stats``.
    """
    # Function-scope import: llmcompressor is only needed when this is called.
    from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
        modify_save_pretrained,
    )

    accelerator = trainer.accelerator
    accelerator.wait_for_everyone()

    # Only the main process writes to disk.
    if not accelerator.is_main_process:
        return

    # NOTE(review): presumably this wraps ``model.save_pretrained`` so that it
    # accepts the compression keyword arguments below -- confirm against
    # llmcompressor.
    modify_save_pretrained(model)

    save_options = {
        "safe_serialization": safe_serialization,
        "save_compressed": save_compressed,
        # Sparsity stats are skipped exactly when saving uncompressed.
        "skip_sparsity_compression_stats": not save_compressed,
    }
    model.save_pretrained(output_dir, **save_options)

View File

@@ -271,6 +271,19 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError: except FileNotFoundError:
pass pass
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
@@ -279,6 +292,7 @@ def save_trained_model(
trainer.model.save_pretrained( trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
) )
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)