Tests, Style, Updates
This commit is contained in:
@@ -4,7 +4,7 @@ LLMCompressor and Sparse Finetuning config models.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
@@ -38,7 +38,3 @@ class LLMCompressorArgs(BaseModel):
|
||||
description="Arguments enabling compression pathways through the LLM Compressor plugins"
|
||||
),
|
||||
]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
@@ -5,11 +5,12 @@ by maintaining masks for zero weights during training.
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, ParamSpec, TypeVar
|
||||
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from llmcompressor import active_session, create_session
|
||||
from llmcompressor.core import callbacks as session_callbacks
|
||||
from llmcompressor.recipe import Recipe
|
||||
from torch.nn import Module
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
@@ -42,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
||||
self.recipe = (
|
||||
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
||||
)
|
||||
self.original_compute_loss = trainer.compute_loss
|
||||
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
||||
create_session()
|
||||
|
||||
@@ -110,6 +112,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
||||
"""
|
||||
super().on_train_end(args, state, control, **kwargs)
|
||||
active_session().finalize()
|
||||
self.trainer.compute_loss_func = self.original_compute_loss
|
||||
|
||||
|
||||
class LLMCompressorPlugin(BasePlugin):
|
||||
@@ -145,7 +148,9 @@ class LLMCompressorPlugin(BasePlugin):
|
||||
return [callback]
|
||||
|
||||
|
||||
def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def compute_loss_wrapper(
|
||||
compute_loss_func: Callable[Concatenate[Module, P], R],
|
||||
) -> Callable[Concatenate[Module, P], R]:
|
||||
"""
|
||||
Wraps the loss computation function to trigger the loss_calculated callback.
|
||||
|
||||
@@ -157,9 +162,9 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
|
||||
@wraps(compute_loss_func)
|
||||
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
loss = compute_loss_func(*args, **kwargs)
|
||||
if active_session().lifecycle.initialized_:
|
||||
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
loss = compute_loss_func(model, *args, **kwargs)
|
||||
if active_session().lifecycle.initialized_ and model.training:
|
||||
session_callbacks.loss_calculated(loss=loss)
|
||||
return loss
|
||||
|
||||
|
||||
@@ -1,15 +1,40 @@
|
||||
from transformers import Trainer
|
||||
"""Utilities for llmcompressor integration with axolotl."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
|
||||
modify_save_pretrained,
|
||||
)
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
|
||||
|
||||
def save_compressed_model(
    model: PreTrainedModel,
    output_dir: Union[str, bytes],
    trainer: Trainer,
    safe_serialization: bool = False,
    save_compressed: bool = False,
) -> None:
    """
    Synchronize processes, then save the model from the main process only.

    All processes first meet at ``trainer.accelerator.wait_for_everyone()``.
    Non-main processes then return without writing anything. The main
    process calls ``modify_save_pretrained(model)`` and afterwards
    ``model.save_pretrained(...)``.

    Args:
        model: The model to be saved.
        output_dir: Path where the model files will be written.
        trainer: Trainer whose ``accelerator`` is used for the barrier and
            for the main-process check.
        safe_serialization: Forwarded unchanged to ``save_pretrained``.
        save_compressed: Forwarded unchanged to ``save_pretrained``. Its
            negation is passed as ``skip_sparsity_compression_stats``.
    """
    trainer.accelerator.wait_for_everyone()

    # Only the main process writes the files.
    if not trainer.accelerator.is_main_process:
        return

    # NOTE(review): presumably this patches ``model.save_pretrained`` so that
    # it accepts the compression-specific keyword arguments used below;
    # confirm against the llmcompressor documentation.
    modify_save_pretrained(model)
    model.save_pretrained(
        output_dir,
        safe_serialization=safe_serialization,
        save_compressed=save_compressed,
        skip_sparsity_compression_stats=not save_compressed,
    )
|
||||
|
||||
Reference in New Issue
Block a user