feat: add convert linear attention cli

src/axolotl/cli/convert_linear_attention.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""CLI to convert a model's attention layers to linear attention and train them via attention transfer (LoLCATs)."""

import logging
import os
from pathlib import Path
from typing import Union

import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linearize_attention import (
    convert_attention,
    remove_base_attention,
    toggle_attention,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import setup_trainer

LOG = logging.getLogger(__name__)


def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    """
    Convert attention to linear attention and perform attention transfer via distillation.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # Ensure quantization and PEFT are turned off (due to how we need to re-apply PEFT later)
    cfg.load_in_8bit = False
    cfg.load_in_4bit = False
    cfg.adapter = None

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(cfg=cfg)

    # Swap the base softmax attention for trainable linear attention
    model = convert_attention(
        model, cfg.attention_config, train_attention=True, remove_base_attn=True
    )

    # Get datasets
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Toggle attention to be trainable
    model = toggle_attention(model, train=True)

    # Set up trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),
        tokenizer=tokenizer,
        processor=None,
        total_num_steps=total_num_steps,
    )

    # Train
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    # Drop base attention + remove training-time attention outputs
    model = toggle_attention(model, train=False)
    model = remove_base_attention(model)

    # NOTE: if in PEFT mode, consider whether to auto-merge
    # Save tokenizer, config, and model weights to the "distilled" subdirectory
    save_path = os.path.join(cfg.output_dir, "distilled")
    tokenizer.save_pretrained(save_path)
    if hasattr(model, "config"):
        model.config.save_pretrained(save_path)

    safe_serialization = cfg.save_safetensors is True
    # NOTE: may need to consider other ways of saving due to multi-GPU etc.
    model.save_pretrained(save_path, safe_serialization=safe_serialization)

    # Cleanup
    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer

    plugin_manager.post_train_unload(cfg)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_linearize`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # Load cfg, force linearize, and add the plugin needed to linearize
    parsed_cfg = load_cfg(
        config,
        linearize=True,
        plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
        **kwargs,
    )

    parser = HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    do_linearize(parsed_cfg, parsed_cli_args)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)

src/axolotl/integrations/lolcats/README.md (new file, 24 lines)
@@ -0,0 +1,24 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)

https://github.com/HazyResearch/lolcats/

### Usage

Step 1: linearize the model's attention layers and run attention transfer.

```yaml
plugins:
  - axolotl.integrations.lolcats.LinearizePlugin

linearize: true
```

Step 2: remove the config above and fine-tune with LoRA, using the possible target modules below.

```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]

# The optional config below is also possible, but it requires patching axolotl
# to allow this setting to work with LoRA:
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```
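
Step 1 can also be driven programmatically (or via `python -m axolotl.cli.convert_linear_attention <config>`); the new CLI forces `linearize: true` and registers the plugin itself. A minimal sketch, where the config path is an illustrative placeholder:

```python
# Minimal sketch: run the Step 1 distillation via the new CLI entry point.
# The config path is a placeholder; extra kwargs override config values.
from axolotl.cli.convert_linear_attention import do_cli

do_cli(config="path/to/linearize.yaml")
```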

src/axolotl/integrations/lolcats/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""
Module for the LoLCATs linear attention plugin integration with Axolotl.

Low-rank Linear Conversion via Attention Transfer
"""

import logging

from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
    DistillAttentionXentMSETrainer,
)

LOG = logging.getLogger("axolotl.integrations.lolcats")


class LinearizePlugin(BasePlugin):
    """
    Plugin for LoLCATs integration with Axolotl.
    """

    def get_input_args(self):
        return "axolotl.integrations.lolcats.LinearAttentionArgs"

    def get_trainer_cls(self, cfg):
        # Default to the XentMSE trainer
        # TODO: add check to allow MSE_linear
        if cfg.linearize:
            return DistillAttentionXentMSETrainer

        return None
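
A minimal sketch of the plugin hooks above in action, assuming a bare `DictDefault` config is acceptable for illustration:

```python
# Minimal sketch: the plugin exposes its pydantic args by dotted path and
# hands axolotl the distillation trainer whenever `linearize` is set.
from axolotl.integrations.lolcats import LinearizePlugin
from axolotl.utils.dict import DictDefault

plugin = LinearizePlugin()
print(plugin.get_input_args())  # axolotl.integrations.lolcats.LinearAttentionArgs
print(plugin.get_trainer_cls(DictDefault({"linearize": True})))  # DistillAttentionXentMSETrainer
print(plugin.get_trainer_cls(DictDefault({"linearize": False})))  # None
```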

src/axolotl/integrations/lolcats/args.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""
Module for handling linear attention input arguments.
"""

from pydantic import BaseModel


class LinearAttentionArgs(BaseModel):
    """
    Input args for linear attention.
    """

    attention_config: dict
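
Since this is a plain pydantic model, `attention_config` is required and must be a dict; a small sketch of the validation behavior (the key/value shown are hypothetical placeholders, not the real LoLCATs schema):

```python
# Minimal sketch: pydantic enforces that attention_config is present and a dict.
# The key/value below are hypothetical, not the real LoLCATs schema.
from pydantic import ValidationError

from axolotl.integrations.lolcats.args import LinearAttentionArgs

args = LinearAttentionArgs(attention_config={"feature_map": "hedgehog"})
print(args.attention_config)

try:
    LinearAttentionArgs()  # missing attention_config -> ValidationError
except ValidationError as err:
    print(err)
```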

src/axolotl/integrations/lolcats/trainer/distill_attention_xent_mse.py (new file, 109 lines)
@@ -0,0 +1,109 @@
"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for the Hugging Face trainer.

This implementation supports distilling against the softmax attention outputs, the softmax attention weights,
or a weighted combination of both.
"""

from typing import Any

from torch import Tensor, nn, tensor

from axolotl.core.trainers.base import AxolotlTrainer


class DistillAttentionXentMSETrainer(AxolotlTrainer):
    """
    Custom trainer class for distilling attentions.
    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize MSE(outputs) and/or CrossEntropy(weights)
    """

    def __init__(
        self,
        model: nn.Module,
        metric_for_best_model: str = "distill/eval/loss",
        mse_factor: float = 1e3,
        xent_factor: float = 0,
        **kwargs: Any,
    ):
        super().__init__(model=model, **kwargs)
        self.metric_for_best_model = metric_for_best_model
        self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
        self.criterion_mse = nn.MSELoss(reduction="mean")
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor
        # self.compute_loss_backprop = False  # Whether we backprop in self.compute_loss  # NOTE: this config seems unnecessary

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> tuple[Tensor, dict]:
        """
        Attention distillation ("attention transfer"):
        for each layer and head, get the attentions and train to
        minimize some combination of MSE and cross-entropy loss.
        """
        # Alias inputs to data
        data = inputs

        # Filter out labels
        inputs = {k: v.to(model.device) for k, v in data.items() if k != "labels"}

        # Forward pass
        outputs = model(**inputs, output_attentions=True, use_cache=False)
        outputs = outputs.get("attentions")

        # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]:
        # n_layers x (predicted_attns, true_attns), where predicted_attns
        # and true_attns have shape (batch, n_heads, q_len, k_len)
        loss_mse = tensor(0.0, device=model.device)
        loss_xent = tensor(0.0, device=model.device)
        n_layers = 0  # Number of layers to distill
        softmax_layers = []  # Layers left as standard softmax attention
        for layer_idx, attns in enumerate(outputs):
            if attns is not None:
                if len(attns) != 2:
                    # Not a (predicted, true) pair; just move it off the GPU
                    attns = attns.cpu()
                else:
                    if self.xent_factor > 0:
                        # Cross-entropy loss over attention weights
                        a_pred, a_true = attns[0]
                        a_pred = a_pred.clamp(
                            min=1e-12
                        ).log()  # nn.CrossEntropyLoss assumes unnormalized logits
                        k_len = a_true.shape[-1]  # batch, n_heads, q_len, k_len
                        # Compute mean cross-entropy over all queries
                        a_pred = a_pred.contiguous().view(-1, k_len)
                        a_true = a_true.contiguous().view(-1, k_len)
                        loss_xent += self.criterion_xent(a_pred, a_true)
                    if self.mse_factor > 0:
                        # MSE loss over attention outputs
                        loss_mse += self.criterion_mse(*attns[1])
                    n_layers += 1
            else:
                softmax_layers.append(layer_idx)
        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        loss = loss_xent + loss_mse

        outputs = {
            "loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
            "loss_mse": loss_mse.item() if self.mse_factor > 0 else 0,
            "mse_factor": self.mse_factor,
            "xent_factor": self.xent_factor,
        }
        if "position_ids" in data:
            outputs["input_len"] = data["position_ids"].shape[1]
            outputs["position_ids"] = data["position_ids"][0].detach().cpu().numpy()
        return loss, outputs
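
As a sanity check on the loss arithmetic above, a self-contained sketch of one layer's contribution on dummy tensors (shapes and factor values are illustrative; soft-label `CrossEntropyLoss` requires PyTorch >= 1.10):

```python
# Minimal sketch of one layer's distillation loss on dummy tensors,
# mirroring compute_loss above; shapes and factor values are illustrative.
import torch
from torch import nn

batch, n_heads, q_len, k_len, d_head = 2, 4, 8, 8, 16
# Teacher ("true") and student ("predicted") attention weights
a_true = torch.softmax(torch.randn(batch, n_heads, q_len, k_len), dim=-1)
a_pred = torch.softmax(torch.randn(batch, n_heads, q_len, k_len), dim=-1)
# Teacher and student attention outputs
y_true = torch.randn(batch, n_heads, q_len, d_head)
y_pred = torch.randn(batch, n_heads, q_len, d_head)

# Cross-entropy over weights: clamp + log to form unnormalized logits,
# then flatten so each query's distribution over keys becomes one row
xent = nn.CrossEntropyLoss(reduction="mean")(
    a_pred.clamp(min=1e-12).log().view(-1, k_len),
    a_true.view(-1, k_len),
)
# MSE over outputs
mse = nn.MSELoss(reduction="mean")(y_pred, y_true)

xent_factor, mse_factor = 0.0, 1e3  # the trainer's defaults
loss = xent * xent_factor + mse * mse_factor
print(loss.item())
```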