Rebase and updates!
This commit is contained in:
@@ -74,3 +74,4 @@ llmcompressor:
|
||||
're:.*down_proj.weight',
|
||||
]
|
||||
start: 0
|
||||
save_compressed: true
|
||||
|
||||
3
setup.py
3
setup.py
@@ -149,8 +149,9 @@ extras_require = {
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
# PENDING: https://github.com/vllm-project/llm-compressor/pull/1352
|
||||
"llmcompressor": [
|
||||
"llmcompressor~=0.5.0",
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -20,9 +20,13 @@ class CompressionArgs(BaseModel):
|
||||
),
|
||||
]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
)
|
||||
save_compressed: Annotated[
|
||||
bool,
|
||||
Field(
|
||||
default=False,
|
||||
description="Whether to save the compressed model after training.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class LLMCompressorArgs(BaseModel):
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, ParamSpec, TypeVar
|
||||
|
||||
from llmcompressor import active_session
|
||||
from llmcompressor import active_session, create_session
|
||||
from llmcompressor.core import callbacks as session_callbacks
|
||||
from llmcompressor.recipe import Recipe
|
||||
from transformers.trainer import Trainer
|
||||
@@ -43,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
||||
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
||||
)
|
||||
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
||||
create_session()
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
@@ -60,13 +61,14 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
||||
control (TrainerControl): Trainer control.
|
||||
"""
|
||||
super().on_train_begin(args, state, control, **kwargs)
|
||||
session = active_session()
|
||||
session.initialize(
|
||||
self.trainer.accelerator.wait_for_everyone()
|
||||
active_session().initialize(
|
||||
model=self.trainer.model,
|
||||
optimizer=self.trainer.optimizer,
|
||||
start=state.epoch,
|
||||
recipe=self.recipe,
|
||||
)
|
||||
self.trainer.accelerator.wait_for_everyone()
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
@@ -107,8 +109,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
||||
Called at the end of training. Finalizes the compression session.
|
||||
"""
|
||||
super().on_train_end(args, state, control, **kwargs)
|
||||
session = active_session()
|
||||
session.finalize()
|
||||
active_session().finalize()
|
||||
|
||||
|
||||
class LLMCompressorPlugin(BasePlugin):
|
||||
@@ -158,7 +159,8 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(compute_loss_func)
|
||||
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
loss = compute_loss_func(*args, **kwargs)
|
||||
session_callbacks.loss_calculated(loss=loss)
|
||||
if active_session().lifecycle.initialized_:
|
||||
session_callbacks.loss_calculated(loss=loss)
|
||||
return loss
|
||||
|
||||
return compute_and_notify
|
||||
|
||||
15
src/axolotl/integrations/llm_compressor/utils.py
Normal file
15
src/axolotl/integrations/llm_compressor/utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from transformers import Trainer
|
||||
|
||||
def save_compressed_model(
    model,
    output_dir,
    trainer: Trainer,
    safe_serialization: bool,
    save_compressed: bool,
):
    """Save ``model`` to ``output_dir`` through llmcompressor's save hook.

    All processes synchronize on ``trainer.accelerator`` first; afterwards
    only the main process patches the model with ``modify_save_pretrained``
    and calls ``model.save_pretrained``.

    Args:
        model: Model whose ``save_pretrained`` is invoked.
        output_dir: Destination passed as the first argument to
            ``save_pretrained``.
        trainer: Trainer whose ``accelerator`` provides the barrier and the
            main-process check.
        safe_serialization: Forwarded unchanged to ``save_pretrained``.
        save_compressed: Forwarded unchanged to ``save_pretrained``; its
            negation is passed as ``skip_sparsity_compression_stats``.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably because llmcompressor is an optional extra — confirm.
    from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
        modify_save_pretrained,
    )

    # Barrier: every rank reaches this point before anything is written.
    trainer.accelerator.wait_for_everyone()

    # Non-main ranks have nothing left to do.
    if not trainer.accelerator.is_main_process:
        return

    # NOTE(review): assumed to wrap model.save_pretrained so that it accepts
    # the compression-specific keyword arguments below — verify against
    # the llmcompressor version in use.
    modify_save_pretrained(model)

    save_options = {
        "safe_serialization": safe_serialization,
        "save_compressed": save_compressed,
        "skip_sparsity_compression_stats": not save_compressed,
    }
    model.save_pretrained(output_dir, **save_options)
|
||||
@@ -288,6 +288,19 @@ def save_trained_model(
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||
from axolotl.integrations.llm_compressor.utils import (
|
||||
save_compressed_model,
|
||||
)
|
||||
|
||||
save_compressed_model(
|
||||
model=model,
|
||||
output_dir=cfg.output_dir,
|
||||
trainer=trainer,
|
||||
safe_serialization=safe_serialization,
|
||||
save_compressed=cfg.llmcompressor.save_compressed,
|
||||
)
|
||||
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
@@ -296,6 +309,7 @@ def save_trained_model(
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user