From ce0cd470f745e4537019b0a3a4ed1060ec277143 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 3 Feb 2025 22:46:09 +0700
Subject: [PATCH] feat: add convert linear attention cli

---
 src/axolotl/cli/convert_linear_attention.py  | 126 ++++++++++++++++++
 src/axolotl/integrations/lolcats/README.md   |  24 ++++
 src/axolotl/integrations/lolcats/__init__.py |  31 +++++
 src/axolotl/integrations/lolcats/args.py     |  13 ++
 .../trainer/distill_attention_xent_mse.py    | 109 +++++++++++++++
 5 files changed, 303 insertions(+)
 create mode 100644 src/axolotl/cli/convert_linear_attention.py
 create mode 100644 src/axolotl/integrations/lolcats/README.md
 create mode 100644 src/axolotl/integrations/lolcats/__init__.py
 create mode 100644 src/axolotl/integrations/lolcats/args.py
 create mode 100644 src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py

diff --git a/src/axolotl/cli/convert_linear_attention.py b/src/axolotl/cli/convert_linear_attention.py
new file mode 100644
index 000000000..061e3fe8b
--- /dev/null
+++ b/src/axolotl/cli/convert_linear_attention.py
@@ -0,0 +1,126 @@
"""CLI to convert a model's attention to linear attention and run attention-transfer distillation."""

import logging
import os
from pathlib import Path
from typing import Union

import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linearize_attention import (
    convert_attention,
    remove_base_attention,
    toggle_attention,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import setup_trainer

LOG = logging.getLogger(__name__)


def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    """
    Convert attention to linear attention and perform attention transfer via distillation.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # Ensure quantization and PEFT are turned off, since PEFT needs to be
    # re-applied later during the LoRA finetuning stage
    cfg.load_in_8bit = False
    cfg.load_in_4bit = False
    cfg.adapter = None

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(cfg=cfg)

    # Swap softmax attention for trainable linear attention, keeping the base
    # attention around as the distillation teacher
    model = convert_attention(
        model, cfg.attention_config, train_attention=True, remove_base_attn=True
    )

    # Get datasets
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Toggle attention to be trainable
    model = toggle_attention(model, train=True)

    # Set up trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),
        tokenizer=tokenizer,
        processor=None,
        total_num_steps=total_num_steps,
    )

    # Train
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    # Freeze the trained attention and drop the stored base (softmax) attention
    model = toggle_attention(model, train=False)
    model = remove_base_attention(model)

    # NOTE: If in PEFT mode, consider whether to auto-merge

    # Save the distilled model, tokenizer, and config together
    save_path = os.path.join(cfg.output_dir, "distilled")
    tokenizer.save_pretrained(save_path)
    if hasattr(model, "config"):
        model.config.save_pretrained(save_path)

    safe_serialization = cfg.save_safetensors is True
    # NOTE: may need to consider other ways of saving due to multi-GPU etc.
    model.save_pretrained(save_path, safe_serialization=safe_serialization)

    # Cleanup
    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer

    plugin_manager.post_train_unload(cfg)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_linearize`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # Load cfg, force linearization, and register the linearize plugin
    parsed_cfg = load_cfg(
        config,
        linearize=True,
        plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
        **kwargs,
    )

    parser = HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    do_linearize(parsed_cfg, parsed_cli_args)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)
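
Since `do_cli` is exposed through `fire.Fire` and forwards extra keyword arguments to `load_cfg` as config overrides, the converter can also be driven programmatically. A minimal sketch; the config path and the `output_dir` override below are illustrative placeholders, not files or values shipped with this patch:

```python
# Sketch: driving the linearize CLI programmatically.
# "configs/linearize.yaml" is a hypothetical axolotl config that defines
# attention_config; output_dir is forwarded to load_cfg as an override.
from axolotl.cli.convert_linear_attention import do_cli

if __name__ == "__main__":
    do_cli(
        config="configs/linearize.yaml",
        output_dir="./outputs/lolcats",
    )
```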
+ """ + print_axolotl_text_art() + check_accelerate_default_config() + check_user_token() + + # ensure quantization and peft are turned off (due to how we need to re-apply peft later) + cfg.load_in_8bit = False + cfg.load_in_4bit = False + cfg.adapter = None + + # load model + model, tokenizer = load_model_and_tokenizer(cfg=cfg) + + # convert attention + from axolotl.integrations.lolcats.linearize_attention import convert_attention + + model = convert_attention( + model, cfg.attention_config, train_attention=True, remove_base_attn=True + ) + + # Get datasets + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + train_dataset = dataset_meta.train_dataset + eval_dataset = dataset_meta.eval_dataset + total_num_steps = dataset_meta.total_num_steps + + # toggle attention to be trainable + model = toggle_attention(model, train=True) + + # Setup trainer + trainer = setup_trainer( + cfg=cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + model=(model, None, None), + tokenizer=tokenizer, + processor=None, + total_num_steps=total_num_steps, + ) + + # train + trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) + + # drop base_attention + remove training attn + model = toggle_attention(model, train=False) + model = remove_base_attention(model) + + # NOTE: If in peft mode, consider whether to auto-merge + + # save model + save_path = str(os.path.join(cfg.output_dir, "distilled")) + tokenizer.save_pretrained(save_path) + if hasattr(model, "config"): + model.config.save_pretrained(save_path) + + safe_serialization = cfg.save_safetensors is True + # NOTE: may need to consider other ways of saving due to multi-gpu etc + model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + + # cleanup + plugin_manager = PluginManager.get_instance() + + del model + del tokenizer + + plugin_manager.post_train_unload(cfg) + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_train`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ + # load cfg, force linearize and add plugin to linearize + parsed_cfg = load_cfg( + config, + linearize=True, + plugins=["axolotl.integrations.lolcats.LinearizePlugin"], + **kwargs, + ) + + parser = HfArgumentParser(TrainerCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_linearize(parsed_cfg, parsed_cli_args) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/integrations/lolcats/README.md b/src/axolotl/integrations/lolcats/README.md new file mode 100644 index 000000000..7b60f06f3 --- /dev/null +++ b/src/axolotl/integrations/lolcats/README.md @@ -0,0 +1,24 @@ +# Low-rank Linear Conversion via Attention Transfer (LoLCATs) + +https://github.com/HazyResearch/lolcats/ + +### Usage + +Step 1: + +```yaml +plugins: + - axolotl.integrations.lolcats.LinearizePlugin + +linearize: true +``` + +Step 2: Remove the config above and finetune with lora with below possible targets. 
diff --git a/src/axolotl/integrations/lolcats/args.py b/src/axolotl/integrations/lolcats/args.py
new file mode 100644
index 000000000..0c2bd6561
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/args.py
@@ -0,0 +1,13 @@
"""
Module for handling linear attention input arguments.
"""

from pydantic import BaseModel


class LinearAttentionArgs(BaseModel):
    """
    Input args for linear attention.
    """

    attention_config: dict
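
Because `attention_config` is typed as an open `dict`, the args model accepts arbitrary keys there. A sketch of constructing it; the keys shown are hypothetical placeholders, since this patch does not define a schema beyond the dict itself:

```python
# Sketch: constructing the args model. The attention_config keys below are
# illustrative placeholders, not a schema defined by this patch.
from axolotl.integrations.lolcats.args import LinearAttentionArgs

args = LinearAttentionArgs(
    attention_config={
        "attention_type": "lolcats_llama",  # hypothetical key
        "feature_map": "softmax_dim",       # hypothetical key
    }
)
print(args.attention_config["attention_type"])  # lolcats_llama
```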
diff --git a/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py
new file mode 100644
index 000000000..582b21057
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py
@@ -0,0 +1,109 @@
"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for the Hugging Face trainer.

This implementation supports distilling against the softmax attention outputs (MSE),
the softmax attention weights (cross-entropy), or both.
"""

from typing import Any

import torch
from torch import Tensor, nn

from axolotl.core.trainers.base import AxolotlTrainer


class DistillAttentionXentMSETrainer(AxolotlTrainer):
    """
    Custom trainer class for distilling attentions.

    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
    """

    def __init__(
        self,
        model: nn.Module,
        metric_for_best_model: str = "distill/eval/loss",
        mse_factor: float = 1e3,
        xent_factor: float = 0,
        **kwargs: Any,
    ):
        super().__init__(model=model, **kwargs)
        self.metric_for_best_model = metric_for_best_model
        self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
        self.criterion_mse = nn.MSELoss(reduction="mean")
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Tensor],
        return_outputs: bool = False,
        num_items_in_batch=None,
    ) -> Tensor | tuple[Tensor, dict]:
        """
        Attention distillation ("attention transfer"):
        for each layer and head, get the attentions and train to
        minimize a weighted combination of MSE and cross-entropy loss.
        """
        # Alias inputs to data
        data = inputs

        # Filter out labels
        inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}

        # Forward pass
        outputs = model(**inputs, output_attentions=True, use_cache=False)
        outputs = outputs.get("attentions")

        # Attentions are a tuple with one entry per layer. For each distillable
        # layer, the entry is itself a pair of (predicted, true) pairs:
        #   attns[0] = (predicted_attn_weights, true_attn_weights)
        #   attns[1] = (predicted_attn_outputs, true_attn_outputs)
        # Attention weights have shape (batch, n_heads, q_len, k_len)
        loss_mse = torch.tensor(0.0, device=model.device)
        loss_xent = torch.tensor(0.0, device=model.device)
        n_layers = 0  # Number of layers to distill
        softmax_layers = []  # Layers still running plain softmax attention
        for layer_idx, attns in enumerate(outputs):
            if attns is not None:
                if len(attns) != 2:
                    # Not a (weights, outputs) pair; move it off the GPU
                    attns = attns.cpu()
                else:
                    if self.xent_factor > 0:
                        # Cross-entropy loss over attention weights
                        a_pred, a_true = attns[0]
                        a_pred = a_pred.clamp(
                            min=1e-12
                        ).log()  # nn.CrossEntropyLoss expects unnormalized logits
                        k_len = a_true.shape[-1]  # (batch, n_heads, q_len, k_len)
                        # Compute mean cross-entropy over all queries
                        a_pred = a_pred.contiguous().view(-1, k_len)
                        a_true = a_true.contiguous().view(-1, k_len)
                        loss_xent += self.criterion_xent(a_pred, a_true)
                    if self.mse_factor > 0:
                        # MSE loss over attention outputs
                        loss_mse += self.criterion_mse(*attns[1])
                    n_layers += 1
            else:
                softmax_layers.append(layer_idx)
        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        loss = loss_xent + loss_mse

        outputs = {
            "loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
            "loss_mse": loss_mse.item() if self.mse_factor > 0 else 0,
            "mse_factor": self.mse_factor,
            "xent_factor": self.xent_factor,
        }
        if "position_ids" in data:
            outputs["input_len"] = data["position_ids"].shape[1]
            outputs["position_ids"] = data["position_ids"][0].detach().cpu().numpy()

        return (loss, outputs) if return_outputs else loss
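
To make the loss concrete, here is a self-contained sketch of the per-layer combination the trainer computes, on dummy tensors and independent of axolotl: cross-entropy between student and teacher attention weights plus MSE between their attention outputs, each scaled by its factor. Shapes and factor values are illustrative only:

```python
# Sketch of the attention-transfer loss on dummy tensors.
import torch
from torch import nn

batch, n_heads, q_len, k_len, head_dim = 2, 4, 16, 16, 8
mse_factor, xent_factor = 1e3, 1.0

# Student ("pred") and teacher ("true") attention weights; rows sum to 1
a_pred = torch.softmax(torch.randn(batch, n_heads, q_len, k_len), dim=-1)
a_true = torch.softmax(torch.randn(batch, n_heads, q_len, k_len), dim=-1)

# Student and teacher attention outputs
o_pred = torch.randn(batch, n_heads, q_len, head_dim)
o_true = torch.randn(batch, n_heads, q_len, head_dim)

# Cross-entropy over each query's distribution over keys: the log of the
# (clamped) student weights acts as logits, the teacher weights as soft targets
logits = a_pred.clamp(min=1e-12).log().view(-1, k_len)
targets = a_true.view(-1, k_len)
loss_xent = nn.CrossEntropyLoss(reduction="mean")(logits, targets)

# MSE between the attention outputs
loss_mse = nn.MSELoss(reduction="mean")(o_pred, o_true)

loss = xent_factor * loss_xent + mse_factor * loss_mse
print(f"xent={loss_xent.item():.4f} mse={loss_mse.item():.4f} total={loss.item():.4f}")
```

A detail worth noting: since the student weights are already normalized, `log_softmax` of their log is the log itself, so the cross-entropy term reduces to the exact soft-target cross-entropy between the two attention distributions.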