Add: SFTPlugin with llmcompressor
This commit is contained in:
76
examples/llama-3/sft.yaml
Normal file
76
examples/llama-3/sft.yaml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
|
||||||
|
# TODO: change to
|
||||||
|
# base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.llmcompressor_sft.SFTPlugin
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
# gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
evals_per_epoch: 2
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
|
recipe:
|
||||||
|
finetuning_stage:
|
||||||
|
finetuning_modifiers:
|
||||||
|
ConstantPruningModifier:
|
||||||
|
targets: [
|
||||||
|
're:.*q_proj.weight',
|
||||||
|
're:.*k_proj.weight',
|
||||||
|
're:.*v_proj.weight',
|
||||||
|
're:.*o_proj.weight',
|
||||||
|
're:.*gate_proj.weight',
|
||||||
|
're:.*up_proj.weight',
|
||||||
|
're:.*down_proj.weight',
|
||||||
|
]
|
||||||
|
start: 0
|
||||||
130
src/axolotl/integrations/llmcompressor_sft/__init__.py
Normal file
130
src/axolotl/integrations/llmcompressor_sft/__init__.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""
|
||||||
|
Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
|
||||||
|
by maintaining masks for zero weights during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
|
||||||
|
from transformers.training_args import TrainingArguments
|
||||||
|
|
||||||
|
from ..base import BasePlugin
|
||||||
|
from .args import LLMCompressorArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
from llmcompressor import initialize
|
||||||
|
from llmcompressor.core import callbacks as session_callbacks
|
||||||
|
from llmcompressor.recipe import Recipe
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
|
||||||
|
|
||||||
|
class SFTCallbackHandler(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Transformer trainer callback for Sparse Finetuning.
|
||||||
|
Maintains sparsity patterns during training by applying masks after optimization steps.
|
||||||
|
This ensures that optimizer updates to zero weights are canceled out.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, trainer: object, recipe: object):
|
||||||
|
"""
|
||||||
|
Initialize the callback handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer (object): The trainer instance.
|
||||||
|
recipe (object): The sparse finetuning recipe to be applied.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.trainer = trainer
|
||||||
|
self.recipe = Recipe.model_validate(recipe)
|
||||||
|
|
||||||
|
if hasattr(self.trainer, "compute_loss"):
|
||||||
|
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
|
||||||
|
|
||||||
|
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event triggered at the beginning of training.
|
||||||
|
Updates the session reference to the model, accommodating changes due to wrappers like FSDP.
|
||||||
|
"""
|
||||||
|
super().on_train_begin(args, state, control, **kwargs)
|
||||||
|
initialize(
|
||||||
|
model=self.trainer.model,
|
||||||
|
optimizer=self.trainer.optimizer,
|
||||||
|
start=state.epoch,
|
||||||
|
recipe=self.recipe,
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event triggered at the beginning of a training step.
|
||||||
|
Calls batch_start in the active CompressionSession.
|
||||||
|
"""
|
||||||
|
super().on_step_begin(args, state, control, **kwargs)
|
||||||
|
session_callbacks.batch_start()
|
||||||
|
|
||||||
|
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event triggered at the end of a training step.
|
||||||
|
Calls optimizer pre-step, post-step, and batch_end callbacks.
|
||||||
|
"""
|
||||||
|
super().on_step_end(args, state, control, **kwargs)
|
||||||
|
session_callbacks.optim_pre_step()
|
||||||
|
session_callbacks.optim_post_step()
|
||||||
|
session_callbacks.batch_end()
|
||||||
|
|
||||||
|
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event triggered at the end of a substep during gradient accumulation.
|
||||||
|
Calls batch_end in the active CompressionSession.
|
||||||
|
"""
|
||||||
|
super().on_substep_end(args, state, control, **kwargs)
|
||||||
|
session_callbacks.batch_end()
|
||||||
|
|
||||||
|
# def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
# super().on_prediction_step(args, state, control, **kwargs)
|
||||||
|
# session_callbacks.loss_calculated()
|
||||||
|
|
||||||
|
class SFTPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for Sparse Finetuning integration with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns the input argument path for the plugin.
|
||||||
|
"""
|
||||||
|
return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
"""
|
||||||
|
Adds Sparse Finetuning callback to the trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (object): Configuration object containing the recipe.
|
||||||
|
trainer (object): Trainer instance to which the callback is added.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list containing the Sparse Finetuning callback.
|
||||||
|
"""
|
||||||
|
LOG.info("Adding Sparse Finetuning callback to the trainer")
|
||||||
|
callback = SFTCallbackHandler(
|
||||||
|
trainer=trainer,
|
||||||
|
recipe=cfg.recipe,
|
||||||
|
)
|
||||||
|
return [callback]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_loss_wrapper(compute_loss_func):
|
||||||
|
"""
|
||||||
|
Wraps the loss computation function to integrate with the active CompressionSession.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compute_loss_func (function): The original loss computation function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Wrapped function that reports the computed loss.
|
||||||
|
"""
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
loss = compute_loss_func(*args, **kwargs)
|
||||||
|
session_callbacks.loss_calculated(loss=loss)
|
||||||
|
# take the mean across multiple GPUs
|
||||||
|
# this is done outside the compute_loss function in the parent
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
||||||
|
return wrapper
|
||||||
13
src/axolotl/integrations/llmcompressor_sft/args.py
Normal file
13
src/axolotl/integrations/llmcompressor_sft/args.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
Pydantic model for accepting `llmcompressor` specific arguments.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCompressorArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Input arguments for Sparse Finetuning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
recipe: Optional[Any] = None
|
||||||
@@ -141,6 +141,24 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
|||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
and model_config.quantization_config
|
and model_config.quantization_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Use a better fix to handle
|
||||||
|
# config.json produced by compressed-tensors
|
||||||
|
# sparse-only model -> will also have a quantization_config
|
||||||
|
|
||||||
|
is_sparse_only_quant_config = bool(
|
||||||
|
not quant_config_exists
|
||||||
|
or (
|
||||||
|
quant_config_exists
|
||||||
|
and model_config.quantization_config["quant_method"] == "compressed-tensors"
|
||||||
|
and not model_config.quantization_config.get("config_groups", False)
|
||||||
|
and model_config.quantization_config.get("sparsity_config", False)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sparse_only_quant_config:
|
||||||
|
quant_config_exists = False
|
||||||
|
|
||||||
quant_config_method_is_gptq = (
|
quant_config_method_is_gptq = (
|
||||||
quant_config_exists
|
quant_config_exists
|
||||||
and "quant_method" in model_config.quantization_config
|
and "quant_method" in model_config.quantization_config
|
||||||
|
|||||||
Reference in New Issue
Block a user