Rebase and updates!
This commit is contained in:
@@ -74,3 +74,4 @@ llmcompressor:
|
|||||||
're:.*down_proj.weight',
|
're:.*down_proj.weight',
|
||||||
]
|
]
|
||||||
start: 0
|
start: 0
|
||||||
|
save_compressed: true
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -149,8 +149,9 @@ extras_require = {
|
|||||||
"vllm": [
|
"vllm": [
|
||||||
"vllm==0.7.2",
|
"vllm==0.7.2",
|
||||||
],
|
],
|
||||||
|
# PENDING: https://github.com/vllm-project/llm-compressor/pull/1352
|
||||||
"llmcompressor": [
|
"llmcompressor": [
|
||||||
"llmcompressor~=0.5.0",
|
"llmcompressor==0.5.1",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,9 +20,13 @@ class CompressionArgs(BaseModel):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
save_compressed: Annotated[
|
||||||
validate_assignment=True,
|
bool,
|
||||||
)
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to save the compressed model after training.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class LLMCompressorArgs(BaseModel):
|
class LLMCompressorArgs(BaseModel):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import logging
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, ParamSpec, TypeVar
|
from typing import Any, Callable, ParamSpec, TypeVar
|
||||||
|
|
||||||
from llmcompressor import active_session
|
from llmcompressor import active_session, create_session
|
||||||
from llmcompressor.core import callbacks as session_callbacks
|
from llmcompressor.core import callbacks as session_callbacks
|
||||||
from llmcompressor.recipe import Recipe
|
from llmcompressor.recipe import Recipe
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
@@ -43,6 +43,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
|||||||
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
|
||||||
)
|
)
|
||||||
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
||||||
|
create_session()
|
||||||
|
|
||||||
def on_train_begin(
|
def on_train_begin(
|
||||||
self,
|
self,
|
||||||
@@ -60,13 +61,14 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
|||||||
control (TrainerControl): Trainer control.
|
control (TrainerControl): Trainer control.
|
||||||
"""
|
"""
|
||||||
super().on_train_begin(args, state, control, **kwargs)
|
super().on_train_begin(args, state, control, **kwargs)
|
||||||
session = active_session()
|
self.trainer.accelerator.wait_for_everyone()
|
||||||
session.initialize(
|
active_session().initialize(
|
||||||
model=self.trainer.model,
|
model=self.trainer.model,
|
||||||
optimizer=self.trainer.optimizer,
|
optimizer=self.trainer.optimizer,
|
||||||
start=state.epoch,
|
start=state.epoch,
|
||||||
recipe=self.recipe,
|
recipe=self.recipe,
|
||||||
)
|
)
|
||||||
|
self.trainer.accelerator.wait_for_everyone()
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self,
|
self,
|
||||||
@@ -107,8 +109,7 @@ class LLMCompressorCallbackHandler(TrainerCallback):
|
|||||||
Called at the end of training. Finalizes the compression session.
|
Called at the end of training. Finalizes the compression session.
|
||||||
"""
|
"""
|
||||||
super().on_train_end(args, state, control, **kwargs)
|
super().on_train_end(args, state, control, **kwargs)
|
||||||
session = active_session()
|
active_session().finalize()
|
||||||
session.finalize()
|
|
||||||
|
|
||||||
|
|
||||||
class LLMCompressorPlugin(BasePlugin):
|
class LLMCompressorPlugin(BasePlugin):
|
||||||
@@ -158,7 +159,8 @@ def compute_loss_wrapper(compute_loss_func: Callable[P, R]) -> Callable[P, R]:
|
|||||||
@wraps(compute_loss_func)
|
@wraps(compute_loss_func)
|
||||||
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
|
def compute_and_notify(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
loss = compute_loss_func(*args, **kwargs)
|
loss = compute_loss_func(*args, **kwargs)
|
||||||
session_callbacks.loss_calculated(loss=loss)
|
if active_session().lifecycle.initialized_:
|
||||||
|
session_callbacks.loss_calculated(loss=loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
return compute_and_notify
|
return compute_and_notify
|
||||||
|
|||||||
15
src/axolotl/integrations/llm_compressor/utils.py
Normal file
15
src/axolotl/integrations/llm_compressor/utils.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
def save_compressed_model(
    model, output_dir, trainer: Trainer, safe_serialization: bool, save_compressed: bool
):
    """Save ``model`` to ``output_dir`` from the main process only.

    All processes first synchronize on ``trainer.accelerator``. Every
    process other than the main one then returns without writing anything.

    Args:
        model: Model whose ``save_pretrained`` method is called.
        output_dir: Destination passed as the first positional argument to
            ``model.save_pretrained``.
        trainer: Trainer whose ``accelerator`` provides the barrier and the
            main-process check.
        safe_serialization: Forwarded unchanged to ``save_pretrained``.
        save_compressed: Forwarded unchanged to ``save_pretrained``; its
            negation is passed as ``skip_sparsity_compression_stats``.
    """
    # Imported lazily so that llmcompressor is only required when this
    # function is actually called.
    from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
        modify_save_pretrained,
    )

    accelerator = trainer.accelerator
    accelerator.wait_for_everyone()

    # Only the main process writes files; all other ranks are done here.
    if not accelerator.is_main_process:
        return

    # NOTE(review): presumably this patches ``model.save_pretrained`` so that
    # it accepts the compression keyword arguments below — confirm against
    # the llmcompressor documentation.
    modify_save_pretrained(model)

    save_options = {
        "safe_serialization": safe_serialization,
        "save_compressed": save_compressed,
        "skip_sparsity_compression_stats": not save_compressed,
    }
    model.save_pretrained(output_dir, **save_options)
|
||||||
@@ -271,6 +271,19 @@ def save_trained_model(
|
|||||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||||
|
from axolotl.integrations.llm_compressor.utils import (
|
||||||
|
save_compressed_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_compressed_model(
|
||||||
|
model=model,
|
||||||
|
output_dir=cfg.output_dir,
|
||||||
|
trainer=trainer,
|
||||||
|
safe_serialization=safe_serialization,
|
||||||
|
save_compressed=cfg.llmcompressor.save_compressed,
|
||||||
|
)
|
||||||
|
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
@@ -279,6 +292,7 @@ def save_trained_model(
|
|||||||
trainer.model.save_pretrained(
|
trainer.model.save_pretrained(
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
)
|
)
|
||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user