feat: add convert linear attention cli

This commit is contained in:
NanoCode012
2025-02-03 22:46:09 +07:00
parent 311d6eb5da
commit ce0cd470f7
5 changed files with 303 additions and 0 deletions

View File

@@ -0,0 +1,126 @@
"""CLI to linearize a model's attention and run attention-transfer distillation."""
import logging
import os
from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linearize_attention import (
remove_base_attention,
toggle_attention,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import setup_trainer
LOG = logging.getLogger(__name__)
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    """
    Convert attention to linear attention and perform attention transfer via distillation.

    Args:
        cfg: Parsed `axolotl` config.
        cli_args: Trainer CLI args (dataset preprocessing flags, etc.).
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # ensure quantization and peft are turned off (due to how we need to re-apply peft later)
    cfg.load_in_8bit = False
    cfg.load_in_4bit = False
    cfg.adapter = None

    # load model
    model, tokenizer = load_model_and_tokenizer(cfg=cfg)

    # convert attention
    # NOTE(review): imported lazily unlike the module's other linearize_attention
    # imports — presumably to defer heavy/circular imports; confirm before hoisting.
    from axolotl.integrations.lolcats.linearize_attention import convert_attention

    model = convert_attention(
        model, cfg.attention_config, train_attention=True, remove_base_attn=True
    )

    # Get datasets
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # toggle attention to be trainable
    model = toggle_attention(model, train=True)

    # Setup trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),
        tokenizer=tokenizer,
        processor=None,
        total_num_steps=total_num_steps,
    )

    # train
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    # drop base_attention + remove training attn
    model = toggle_attention(model, train=False)
    model = remove_base_attention(model)

    # NOTE: If in peft mode, consider whether to auto-merge

    # Save model, tokenizer, and config together under <output_dir>/distilled.
    # BUGFIX: previously the model weights went to cfg.output_dir while the
    # tokenizer and config went to save_path, splitting the checkpoint across
    # two directories.
    save_path = str(os.path.join(cfg.output_dir, "distilled"))
    tokenizer.save_pretrained(save_path)
    if hasattr(model, "config"):
        model.config.save_pretrained(save_path)

    safe_serialization = cfg.save_safetensors is True
    # NOTE: may need to consider other ways of saving due to multi-gpu etc
    model.save_pretrained(save_path, safe_serialization=safe_serialization)

    # cleanup
    plugin_manager = PluginManager.get_instance()
    del model
    del tokenizer
    plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parse the `axolotl` config and CLI args, then run `do_linearize`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # Force-enable linearization and inject the LoLCATs plugin on top of
    # whatever the user's config file specifies.
    cfg = load_cfg(
        config,
        linearize=True,
        plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
        **kwargs,
    )

    cli_args, _remaining = HfArgumentParser(
        TrainerCliArgs
    ).parse_args_into_dataclasses(return_remaining_strings=True)

    do_linearize(cfg, cli_args)
# Script entry point: load environment variables (e.g. credentials) from a
# .env file, then expose `do_cli` as a Fire command-line interface.
if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)

View File

@@ -0,0 +1,24 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
https://github.com/HazyResearch/lolcats/
### Usage
Step 1:
```yaml
plugins:
- axolotl.integrations.lolcats.LinearizePlugin
linearize: true
```
Step 2: Remove the config above and fine-tune with LoRA using the possible target modules below.
```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
# with optional config below but this requires patching axolotl
# to allow this config to work with lora
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```

View File

@@ -0,0 +1,31 @@
"""
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
Low-rank Linear Conversion via Attention Transfer
"""
import logging
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
DistillAttentionXentMSETrainer,
)
LOG = logging.getLogger("axolotl.integrations.lolcats")
class LinearizePlugin(BasePlugin):
    """
    Plugin for lolcats integration with Axolotl.
    """

    def get_input_args(self):
        """Return the dotted path of the pydantic model holding this plugin's config args."""
        return "axolotl.integrations.lolcats.LinearAttentionArgs"

    def get_trainer_cls(self, cfg):
        """Return the distillation trainer class when linearizing, else None."""
        # Default to XentMSE
        # TODO: add check to allow MSE_linear
        return DistillAttentionXentMSETrainer if cfg.linearize else None

View File

@@ -0,0 +1,13 @@
"""
Module for handling linear attention input arguments.
"""
from pydantic import BaseModel
class LinearAttentionArgs(BaseModel):
    """
    Input args for linear attention
    """

    # Configuration dict forwarded to `convert_attention` describing how to
    # linearize the model's attention layers (exact schema is defined by the
    # lolcats integration — TODO confirm against linearize_attention).
    attention_config: dict

View File

@@ -0,0 +1,109 @@
"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
"""
from typing import Any
from torch import Tensor, nn, tensor
from axolotl.core.trainers.base import AxolotlTrainer
class DistillAttentionXentMSETrainer(AxolotlTrainer):
    """
    Custom trainer class for distilling attentions.
    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
    """

    def __init__(
        self,
        model: nn.Module,
        metric_for_best_model: str = "distill/eval/loss",
        mse_factor: float = 1e3,
        xent_factor: float = 0,
        **kwargs: Any,
    ):
        """
        Args:
            model: Student model whose layers emit (predicted, true) attention pairs.
            metric_for_best_model: Metric key used for best-checkpoint selection.
            mse_factor: Weight on the MSE(outputs) term; 0 disables it.
            xent_factor: Weight on the CrossEntropy(weights) term; 0 disables it.
            kwargs: Forwarded to AxolotlTrainer.
        """
        super().__init__(model=model, **kwargs)
        self.metric_for_best_model = metric_for_best_model
        self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
        self.criterion_mse = nn.MSELoss(reduction="mean")
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> "Tensor | tuple[Tensor, dict]":
        """
        Attention distillation ("attention transfer")
        - For each layer and head, get attentions and train to
          minimize some combo of MSE and cross-entropy loss

        Returns:
            The scalar distillation loss, or ``(loss, outputs)`` when
            ``return_outputs`` is True — matching the HF ``Trainer`` contract
            so ``training_step`` receives a bare loss tensor to backprop.
        """
        # alias inputs to data
        data = inputs
        # Filter out labels
        inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}
        # Forward pass
        outputs = model(**inputs, output_attentions=True, use_cache=False)
        outputs = outputs.get("attentions")
        # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
        # n_layers x (predicted_attns, true_attns)
        # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
        # NOTE(review): accumulators are created on CPU; adding a CUDA-resident
        # criterion result to them assumes matching devices — confirm on GPU runs.
        loss_mse = tensor(0.0)
        loss_xent = tensor(0.0)
        n_layers = 0  # Number of layers to distill
        softmax_layers = []  # layers still on softmax attention (tracked, not used yet)
        for layer_idx, attns in enumerate(outputs):
            if attns is not None:
                if len(attns) != 2:
                    # Not a ((pred, true) weights, (pred, true) outputs) pair;
                    # move off-device to free GPU memory.
                    attns = attns.cpu()
                else:
                    if self.xent_factor > 0:
                        # Cross-entropy loss over attention weights
                        a_pred, a_true = attns[0]
                        a_pred = a_pred.clamp(
                            min=1e-12
                        ).log()  # nn.CrossEntropy assumes unnormalized logits
                        k_len = a_true.shape[-1]  # batch, n_heads, q_len, k_len
                        # Compute mean cross-entropy over all queries
                        a_pred = a_pred.contiguous().view(-1, k_len)
                        a_true = a_true.contiguous().view(-1, k_len)
                        loss_xent += self.criterion_xent(a_pred, a_true)
                    if self.mse_factor > 0:
                        # MSE between predicted and true attention outputs
                        loss_mse += self.criterion_mse(*attns[1])
                    n_layers += 1
            else:
                softmax_layers.append(layer_idx)
        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        loss = loss_xent + loss_mse

        # Logging payload. BUGFIX: loss_mse is now reduced with .item() like
        # loss_xent — previously the live tensor was stored, retaining the
        # autograd graph past this step.
        outputs = {
            "loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
            "loss_mse": loss_mse.item() if self.mse_factor > 0 else 0,
            "mse_factor": self.mse_factor,
            "xent_factor": self.xent_factor,
        }
        if "position_ids" in data:
            outputs["input_len"] = data["position_ids"].shape[1]
            outputs["position_ids"] = data["position_ids"][0].detach().cpu().numpy()

        # BUGFIX: honor return_outputs — the method previously always returned
        # (loss, outputs), which would hand a tuple to callers expecting a loss.
        return (loss, outputs) if return_outputs else loss