feat: add convert linear attention cli
This commit is contained in:
126
src/axolotl/cli/convert_linear_attention.py
Normal file
126
src/axolotl/cli/convert_linear_attention.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
|
from axolotl.cli.config import load_cfg
|
||||||
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.integrations.lolcats.linearize_attention import (
|
||||||
|
remove_base_attention,
|
||||||
|
toggle_attention,
|
||||||
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    """
    Convert attention to linear attention and perform attention transfer via distillation.

    The model is loaded without quantization or adapters, its attention is
    converted using `cfg.attention_config`, the converted attention is trained,
    and the base attention is then removed. The tokenizer, config and weights
    are all written to `<cfg.output_dir>/distilled`.

    Args:
        cfg: Parsed `axolotl` config. `load_in_8bit`, `load_in_4bit` and
            `adapter` are overwritten in place.
        cli_args: Trainer CLI arguments, forwarded to dataset loading.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # ensure quantization and peft are turned off (due to how we need to re-apply peft later)
    cfg.load_in_8bit = False
    cfg.load_in_4bit = False
    cfg.adapter = None

    # load model
    model, tokenizer = load_model_and_tokenizer(cfg=cfg)

    # convert attention
    # pylint: disable=import-outside-toplevel
    from axolotl.integrations.lolcats.linearize_attention import convert_attention

    model = convert_attention(
        model, cfg.attention_config, train_attention=True, remove_base_attn=True
    )

    # Get datasets
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # toggle attention to be trainable
    model = toggle_attention(model, train=True)

    # Setup trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),
        tokenizer=tokenizer,
        processor=None,
        total_num_steps=total_num_steps,
    )

    # train
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    # drop base_attention + remove training attn
    model = toggle_attention(model, train=False)
    model = remove_base_attention(model)

    # NOTE: If in peft mode, consider whether to auto-merge

    # save model: tokenizer, config and weights all go to the same directory so
    # that `<output_dir>/distilled` is a complete, loadable checkpoint
    save_path = os.path.join(cfg.output_dir, "distilled")
    tokenizer.save_pretrained(save_path)
    if hasattr(model, "config"):
        model.config.save_pretrained(save_path)

    safe_serialization = cfg.save_safetensors is True
    # NOTE: may need to consider other ways of saving due to multi-gpu etc
    model.save_pretrained(save_path, safe_serialization=safe_serialization)

    # cleanup
    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer

    plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_linearize`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # Linearization and the lolcats plugin are always forced on for this command.
    cfg = load_cfg(
        config,
        linearize=True,
        plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
        **kwargs,
    )

    # Unrecognized CLI strings are returned separately and ignored here.
    cli_parser = HfArgumentParser(TrainerCliArgs)
    cli_args, _unused = cli_parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    do_linearize(cfg, cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Load environment variables via python-dotenv before dispatching the CLI.
    load_dotenv()
    # Expose `do_cli` as the command-line entry point.
    fire.Fire(do_cli)
|
||||||
24
src/axolotl/integrations/lolcats/README.md
Normal file
24
src/axolotl/integrations/lolcats/README.md
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
|
||||||
|
|
||||||
|
https://github.com/HazyResearch/lolcats/
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
Step 1: Enable the plugin and linearization in your config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.lolcats.LinearizePlugin
|
||||||
|
|
||||||
|
linearize: true
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 2: Remove the config above and fine-tune with LoRA, using the possible target modules below.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||||
|
|
||||||
|
# with optional config below but this requires patching axolotl
|
||||||
|
# to allow this config to work with lora
|
||||||
|
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
|
||||||
|
```
|
||||||
31
src/axolotl/integrations/lolcats/__init__.py
Normal file
31
src/axolotl/integrations/lolcats/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
|
||||||
|
|
||||||
|
Low-rank Linear Conversion via Attention Transfer
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
|
||||||
|
DistillAttentionXentMSETrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.lolcats")
|
||||||
|
|
||||||
|
|
||||||
|
class LinearizePlugin(BasePlugin):
    """
    Plugin for lolcats integration with Axolotl.
    """

    def get_input_args(self):
        """
        Return the dotted import path of the plugin's input-args model.
        """
        # `LinearAttentionArgs` is defined in the `args` submodule and is not
        # re-exported by this package, so the path must include `.args` to be
        # importable.
        return "axolotl.integrations.lolcats.args.LinearAttentionArgs"

    def get_trainer_cls(self, cfg):
        """
        Return the distillation trainer class when `cfg.linearize` is truthy,
        otherwise `None`.
        """
        # default to XentMSE
        # TODO: add check to allow MSE_linear
        if cfg.linearize:
            return DistillAttentionXentMSETrainer

        return None
|
||||||
13
src/axolotl/integrations/lolcats/args.py
Normal file
13
src/axolotl/integrations/lolcats/args.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
Module for handling linear attention input arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionArgs(BaseModel):
    """
    Pydantic model declaring the config input for linear attention.
    """

    # Required dict of linear attention settings; its schema is not
    # validated here beyond being a dict.
    attention_config: dict
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
"""
|
||||||
|
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
|
||||||
|
|
||||||
|
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from torch import Tensor, nn, tensor
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class DistillAttentionXentMSETrainer(AxolotlTrainer):
    """
    Custom trainer class for distilling attentions.
    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
    """

    def __init__(
        self,
        model: nn.Module,
        metric_for_best_model: str = "distill/eval/loss",
        mse_factor: float = 1e3,
        xent_factor: float = 0,
        **kwargs: Any,
    ):
        """
        Args:
            model: Model to train; forwarded to the parent trainer.
            metric_for_best_model: Stored on the trainer as `metric_for_best_model`.
            mse_factor: Weight of the MSE term; the term is skipped when <= 0.
            xent_factor: Weight of the cross-entropy term; skipped when <= 0.
            kwargs: Forwarded unchanged to the parent trainer.
        """
        super().__init__(model=model, **kwargs)
        self.metric_for_best_model = metric_for_best_model
        self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
        self.criterion_mse = nn.MSELoss(reduction="mean")
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> "Tensor | tuple[Tensor, dict]":
        """
        Attention distillation ("attention transfer")
        - For each layer and head, get attentions and train to
          minimize some combo of MSE and cross-entropy loss

        Args:
            model: Model to run the forward pass on.
            inputs: Batch of tensors; the `labels` entry is not passed to the model.
            return_outputs: When true, also return the dict of logged values.
            num_items_in_batch: Accepted for trainer compatibility; unused.

        Returns:
            The combined loss, or `(loss, outputs)` when `return_outputs` is true.
        """
        # Filter out labels
        model_inputs = {
            k: v.to(model.device) for k, v in inputs.items() if k != "labels"
        }

        # Forward pass
        attentions = model(
            **model_inputs, output_attentions=True, use_cache=False
        ).get("attentions")

        # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
        # n_layers x (predicted_attns, true_attns)
        # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
        # Accumulators live on the model's device rather than the CPU.
        loss_mse = tensor(0.0, device=model.device)
        loss_xent = tensor(0.0, device=model.device)
        n_layers = 0  # Number of layers to distill
        for attns in attentions:
            # Only a 2-item tuple/list is a distillation pair. A plain tensor
            # must not be mistaken for one just because its first dim is 2.
            if not (isinstance(attns, (tuple, list)) and len(attns) == 2):
                continue

            if self.xent_factor > 0:
                # Cross-entropy loss
                a_pred, a_true = attns[0]
                # nn.CrossEntropy assumes unnormalized logits
                a_pred = a_pred.clamp(min=1e-12).log()
                k_len = a_true.shape[-1]  # batch, n_heads, q_len, k_len
                # Compute mean cross-entropy over all queries
                a_pred = a_pred.contiguous().view(-1, k_len)
                a_true = a_true.contiguous().view(-1, k_len)
                layer_xent = self.criterion_xent(a_pred, a_true)
                loss_xent = loss_xent + layer_xent.to(loss_xent.device)

            if self.mse_factor > 0:
                layer_mse = self.criterion_mse(*attns[1])
                loss_mse = loss_mse + layer_mse.to(loss_mse.device)

            n_layers += 1

        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        loss = loss_xent + loss_mse

        # Report plain numbers so the logged values do not hold on to the graph.
        outputs = {
            "loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
            "loss_mse": loss_mse.item() if self.mse_factor > 0 else 0,
        }
        if "position_ids" in inputs:
            outputs["input_len"] = inputs["position_ids"].shape[1]
            outputs["position_ids"] = inputs["position_ids"][0].detach().cpu().numpy()
        outputs["mse_factor"] = self.mse_factor
        outputs["xent_factor"] = self.xent_factor

        # Honor the trainer contract: a bare loss unless outputs are requested.
        return (loss, outputs) if return_outputs else loss
|
||||||
Reference in New Issue
Block a user