From 3da866b2b9a297dbea02cc7641c812cf0f43bd58 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Wed, 12 Mar 2025 07:09:06 +0000
Subject: [PATCH] Add: SFTPlugin with llmcompressor

---
 examples/llama-3/sft.yaml                      |  76 ++++++++++
 .../llmcompressor_sft/__init__.py              | 130 ++++++++++++++++++
 .../integrations/llmcompressor_sft/args.py     |  13 ++
 src/axolotl/utils/models.py                    |  18 +++
 4 files changed, 237 insertions(+)
 create mode 100644 examples/llama-3/sft.yaml
 create mode 100644 src/axolotl/integrations/llmcompressor_sft/__init__.py
 create mode 100644 src/axolotl/integrations/llmcompressor_sft/args.py

diff --git a/examples/llama-3/sft.yaml b/examples/llama-3/sft.yaml
new file mode 100644
index 000000000..5077be697
--- /dev/null
+++ b/examples/llama-3/sft.yaml
@@ -0,0 +1,76 @@
+base_model: "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
+# TODO: change to
+# base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
+
+plugins:
+  - axolotl.integrations.llmcompressor_sft.SFTPlugin
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+eval_sample_packing: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+# gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>
+recipe:
+  finetuning_stage:
+    finetuning_modifiers:
+      ConstantPruningModifier:
+        targets: [
+          're:.*q_proj.weight',
+          're:.*k_proj.weight',
+          're:.*v_proj.weight',
+          're:.*o_proj.weight',
+          're:.*gate_proj.weight',
+          're:.*up_proj.weight',
+          're:.*down_proj.weight',
+        ]
+        start: 0
\ No newline at end of file
diff --git a/src/axolotl/integrations/llmcompressor_sft/__init__.py b/src/axolotl/integrations/llmcompressor_sft/__init__.py
new file mode 100644
index 000000000..c888f6797
--- /dev/null
+++ b/src/axolotl/integrations/llmcompressor_sft/__init__.py
@@ -0,0 +1,130 @@
+"""
+Sparse Finetuning plugin for Axolotl - enables handling of sparse neural networks
+by maintaining masks for zero weights during training.
+"""
+
+import logging
+from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
+from transformers.training_args import TrainingArguments
+
+from ..base import BasePlugin
+from .args import LLMCompressorArgs  # pylint: disable=unused-import. # noqa: F401
+from llmcompressor import initialize
+from llmcompressor.core import callbacks as session_callbacks
+from llmcompressor.recipe import Recipe
+
+LOG = logging.getLogger("axolotl.integrations.llmcompressor_sft")
+
+class SFTCallbackHandler(TrainerCallback):
+    """
+    Transformer trainer callback for Sparse Finetuning.
+    Maintains sparsity patterns during training by applying masks after optimization steps.
+    This ensures that optimizer updates to zero weights are canceled out.
+    """
+
+    def __init__(self, trainer: object, recipe: object):
+        """
+        Initialize the callback handler.
+
+        Args:
+            trainer (object): The trainer instance.
+            recipe (object): The sparse finetuning recipe to be applied.
+        """
+        super().__init__()
+        self.trainer = trainer
+        self.recipe = Recipe.model_validate(recipe)
+
+        if hasattr(self.trainer, "compute_loss"):
+            self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event triggered at the beginning of training.
+        Updates the session reference to the model, accommodating changes due to wrappers like FSDP.
+        """
+        super().on_train_begin(args, state, control, **kwargs)
+        initialize(
+            model=self.trainer.model,
+            optimizer=self.trainer.optimizer,
+            start=state.epoch,
+            recipe=self.recipe,
+        )
+
+    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event triggered at the beginning of a training step.
+        Calls batch_start in the active CompressionSession.
+        """
+        super().on_step_begin(args, state, control, **kwargs)
+        session_callbacks.batch_start()
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event triggered at the end of a training step.
+        Calls optimizer pre-step, post-step, and batch_end callbacks.
+        """
+        super().on_step_end(args, state, control, **kwargs)
+        session_callbacks.optim_pre_step()
+        session_callbacks.optim_post_step()
+        session_callbacks.batch_end()
+
+    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event triggered at the end of a substep during gradient accumulation.
+        Calls batch_end in the active CompressionSession.
+        """
+        super().on_substep_end(args, state, control, **kwargs)
+        session_callbacks.batch_end()
+
+    # def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    #     super().on_prediction_step(args, state, control, **kwargs)
+    #     session_callbacks.loss_calculated()
+
+class SFTPlugin(BasePlugin):
+    """
+    Plugin for Sparse Finetuning integration with Axolotl.
+    """
+
+    def get_input_args(self) -> str:
+        """
+        Returns the input argument path for the plugin.
+        """
+        return "axolotl.integrations.llmcompressor_sft.LLMCompressorArgs"
+
+    def add_callbacks_post_trainer(self, cfg, trainer):
+        """
+        Adds Sparse Finetuning callback to the trainer.
+
+        Args:
+            cfg (object): Configuration object containing the recipe.
+            trainer (object): Trainer instance to which the callback is added.
+
+        Returns:
+            list: A list containing the Sparse Finetuning callback.
+        """
+        LOG.info("Adding Sparse Finetuning callback to the trainer")
+        callback = SFTCallbackHandler(
+            trainer=trainer,
+            recipe=cfg.recipe,
+        )
+        return [callback]
+
+
+def compute_loss_wrapper(compute_loss_func):
+    """
+    Wraps the loss computation function to integrate with the active CompressionSession.
+
+    Args:
+        compute_loss_func (function): The original loss computation function.
+
+    Returns:
+        function: Wrapped function that reports the computed loss.
+    """
+    def wrapper(*args, **kwargs):
+        loss = compute_loss_func(*args, **kwargs)
+        session_callbacks.loss_calculated(loss=loss)
+        # take the mean across multiple GPUs
+        # this is done outside the compute_loss function in the parent
+        loss = loss.mean()
+        return loss
+    return wrapper
\ No newline at end of file
diff --git a/src/axolotl/integrations/llmcompressor_sft/args.py b/src/axolotl/integrations/llmcompressor_sft/args.py
new file mode 100644
index 000000000..a117b862c
--- /dev/null
+++ b/src/axolotl/integrations/llmcompressor_sft/args.py
@@ -0,0 +1,13 @@
+"""
+Pydantic model for accepting `llmcompressor` specific arguments.
+"""
+from typing import Optional, Any
+from pydantic import BaseModel
+
+
+class LLMCompressorArgs(BaseModel):
+    """
+    Input arguments for Sparse Finetuning.
+    """
+
+    recipe: Optional[Any] = None
\ No newline at end of file
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e88de1bad..b18a43021 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -141,6 +141,24 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
         hasattr(model_config, "quantization_config")
         and model_config.quantization_config
     )
+
+    # TODO: Use a better fix to handle
+    # config.json produced by compressed-tensors
+    # sparse-only model -> will also have a quantization_config
+
+    is_sparse_only_quant_config = bool(
+        not quant_config_exists
+        or (
+            quant_config_exists
+            and model_config.quantization_config["quant_method"] == "compressed-tensors"
+            and not model_config.quantization_config.get("config_groups", False)
+            and model_config.quantization_config.get("sparsity_config", False)
+        )
+    )
+
+    if is_sparse_only_quant_config:
+        quant_config_exists = False
+
     quant_config_method_is_gptq = (
         quant_config_exists
        and "quant_method" in model_config.quantization_config
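
For intuition, the sparsity-preserving behaviour that ConstantPruningModifier provides during finetuning can be sketched in plain PyTorch: capture the zero-mask of the sparse checkpoint once, then re-apply it after every optimizer step so pruned weights cannot be revived. The snippet below is only a minimal, self-contained illustration with toy shapes chosen for the example; it is not llmcompressor's implementation, which drives this declaratively from the recipe in examples/llama-3/sft.yaml through the CompressionSession callbacks wired up in SFTCallbackHandler.

# Conceptual sketch only (assumed toy shapes, standard PyTorch APIs): shows why
# re-applying a frozen zero-mask after each optimizer step keeps a sparse
# checkpoint sparse while its non-zero weights are finetuned.
import torch

layer = torch.nn.Linear(8, 8, bias=False)
with torch.no_grad():
    layer.weight[::2] = 0.0                  # pretend the checkpoint arrived 50% sparse

mask = (layer.weight != 0).detach()          # freeze the sparsity pattern at the start
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

for _ in range(3):
    loss = layer(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()                         # a dense update would revive pruned weights...
    optimizer.zero_grad()
    with torch.no_grad():
        layer.weight *= mask                 # ...so the mask is re-applied after the step

assert torch.all(layer.weight[~mask] == 0)   # pruned positions are still exactly zero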