Compare commits

..

40 Commits

Author SHA1 Message Date
Wing Lian
6905711e45 set max steps to -1 when empty 2025-02-06 17:27:52 -05:00
Wing Lian
bb5a6135eb don't set total num steps for grpo 2025-02-06 17:23:13 -05:00
Wing Lian
e637f9b1a4 cleanup pythonpath if axo in it 2025-02-06 17:03:21 -05:00
Wing Lian
1a3bfd6e0f test not deleting pythonpath for custom code bundling
clean path and add mounts
handle mounting
2025-02-06 17:01:19 -05:00
Wing Lian
3df4df868c make sure to pass kwargs when using accelerate 2025-02-06 14:00:15 -05:00
Wing Lian
c82cbdc6d9 make sure to handle num-processes with cloud 2025-02-06 13:50:39 -05:00
Wing Lian
ecea44c902 fix num_processes in passing to accelerate 2025-02-06 13:39:46 -05:00
Wing Lian
4f9c57e95d check for src axolotl in PYTHONPATH before removing it 2025-02-06 13:26:23 -05:00
Wing Lian
3d38bc82b8 include vllm in build 2025-02-06 11:09:42 -05:00
Wing Lian
756a8332d6 set default on trl config 2025-02-05 22:17:10 -05:00
Wing Lian
aded9c500d refactor cfg.grpo_* to use cfg.trl.* 2025-02-05 20:41:14 -05:00
Wing Lian
3659d812f7 use cfg.max_completion_length, not sequence_len 2025-02-05 13:20:17 -05:00
Salman Mohammadi
bdb0f97082 adding 'reward_processing_classes' 2025-02-05 18:18:42 +00:00
Salman Mohammadi
65b6519447 adding 'reward_processing_classes' 2025-02-05 18:13:05 +00:00
Wing Lian
a1958b09de seperately include max_completion_len 2025-02-05 13:01:52 -05:00
Salman Mohammadi
b8f258817e adding reward fn verification 2025-02-05 13:30:02 +00:00
Wing Lian
753146b458 max_length moved to reward config 2025-02-04 11:06:26 -05:00
Wing Lian
d683c50113 fix config cls 2025-02-04 11:06:26 -05:00
Wing Lian
234cd8311e fix failure case in prompter loading 2025-02-04 11:06:26 -05:00
Wing Lian
f9893e3842 fix dpo config and add use_logits_to_keep 2025-02-04 11:06:26 -05:00
Wing Lian
ac1ebc58a8 add support for num_generations 2025-02-04 11:06:25 -05:00
Wing Lian
56f3b9f20f bump pydantic to support vllm 2025-02-04 11:06:25 -05:00
Wing Lian
2c1376d8c4 don't shrink embeddings unless told to 2025-02-04 11:06:25 -05:00
Wing Lian
3c7517fd55 add support for passing map kwargs to dataset map in rl 2025-02-04 11:06:25 -05:00
Wing Lian
1e94d7ef65 more fixes to get grpo working 2025-02-04 11:06:25 -05:00
Wing Lian
cfc7fe0df2 remove ununsable args kwargs 2025-02-04 11:06:25 -05:00
Wing Lian
3c4fe478cf be nice with self.cfg.dataset_processes 2025-02-04 11:06:25 -05:00
Wing Lian
c810599c66 order matters 2025-02-04 11:06:24 -05:00
Wing Lian
300ffc2cb6 make it a dataclass 2025-02-04 11:06:24 -05:00
Wing Lian
b1c4711145 load the class from strat 2025-02-04 11:06:24 -05:00
Wing Lian
d155849e2c use correct builder 2025-02-04 11:06:24 -05:00
Wing Lian
626db6cb84 collator for grpo and prompt loader 2025-02-04 11:06:24 -05:00
Wing Lian
79159b4871 support custom module prompt strategy for rl 2025-02-04 11:06:24 -05:00
Wing Lian
704ddd6ff1 honor skip prepare for rl 2025-02-04 11:06:24 -05:00
Wing Lian
54b0d3d0e8 passthrough dataset parser for dpo/grpo 2025-02-04 11:06:23 -05:00
Wing Lian
59ad21f2de refactor a bit for better grpo support 2025-02-04 11:06:23 -05:00
Wing Lian
57264b6491 respect dotenv for cli 2025-02-04 11:06:23 -05:00
Wing Lian
d495e41ba1 refactor dpo trainer into own module 2025-02-04 11:06:23 -05:00
Wing Lian
6067fe6c28 upgrade trl to 0.14.0 2025-02-04 11:06:23 -05:00
NanoCode012
a620d481e2 fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing

* fix: logging import

* fix: cfg passed being none

* fix: try to fix logging

* fix: refactor call to not use accelerate log

* fix: try to fix circular import issue

* fix: don't drop when skip prepare

* chore: remove duplicate line

* fix: update warning to mention that sequences will be trimmed

* fix: do not drop seq if input_ids don't exist

* fix: increase RM unittest sequence length to reduce trim warnings

* fix: solve conflicts

* fix: default min_seq_len in case of None
2025-02-04 09:43:35 -05:00
52 changed files with 593 additions and 7653 deletions

View File

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.48.1
transformers==4.48.2
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.13.0
trl==0.14.0
optimum==1.16.2
hf_transfer
@@ -26,7 +26,7 @@ sentencepiece
gradio==3.50.2
modal==0.70.5
pydantic==2.6.3
pydantic==2.10.6
addict
fire
PyYAML>=6.0

View File

@@ -153,5 +153,8 @@ setup(
"ray": [
"ray[train]",
],
"vllm": [
"vllm>=0.7.1",
],
},
)

View File

@@ -35,13 +35,18 @@ def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cwd=None,
**kwargs,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
def do_cli_lm_eval(

View File

@@ -7,6 +7,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
import modal
@@ -22,8 +23,18 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
paths = ["/workspace/mounts"]
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
sub_python_path = Path(sub_python_path_str)
if not sub_python_path.joinpath("src", "axolotl").exists():
# we don't want to use the automounted axolotl or unexpected behavior happens
paths.append(str(sub_python_path))
if paths:
new_env["PYTHONPATH"] = ":".join(paths)
else:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
@@ -203,9 +214,12 @@ class ModalCloud(Cloud):
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
def get_train_env(self, local_dirs=None):
image = self.get_image()
for mount, local_dir in (local_dirs or {}).items():
image = image.add_local_dir(local_dir, mount)
return self.app.function(
image=self.get_image(),
image=image,
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
@@ -214,14 +228,21 @@ class ModalCloud(Cloud):
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
def lm_eval(self, config_yaml: str):
@@ -252,7 +273,7 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
@@ -262,8 +283,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -1,135 +0,0 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
LinearLlamaConfig,
)
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
LinearLlamaForCausalLM,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.trainer import setup_trainer
LOG = logging.getLogger(__name__)
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Convert attention to linear attention and perform attention transfer via distillation.
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
# ensure quantization and peft are turned off (due to how we need to re-apply peft later)
cfg.load_in_8bit = False
cfg.load_in_4bit = False
cfg.adapter = None
# load model
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
# freeze model
for p in model.parameters():
p.requires_grad = False
# convert to linear llama
linear_llama_config = LinearLlamaConfig.from_llama(
model.config, cfg.attention_config
)
model = LinearLlamaForCausalLM.from_llama(
model, config=linear_llama_config, train_attention=True
)
# set save_path, save tokenizer and model config.
save_path = str(os.path.join(cfg.output_dir, "distilled"))
tokenizer.save_pretrained(save_path)
if hasattr(model, "config"):
model.config.save_pretrained(save_path)
# Get datasets
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# toggle attention to be trainable
model.toggle_attention(train=True)
# Setup trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=(model, None, None),
tokenizer=tokenizer,
processor=None,
total_num_steps=total_num_steps,
)
# train
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
# drop base_attention + remove training attn
model.toggle_attention(train=False)
model.remove_base_attention()
# NOTE: If in peft mode, consider whether to auto-merge
# save model
safe_serialization = cfg.save_safetensors is True
# NOTE: may need to consider other ways of saving due to multi-gpu etc
model.save_pretrained(save_path, safe_serialization=safe_serialization)
# cleanup
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# load cfg, force linearize and add plugin to linearize
parsed_cfg = load_cfg(
config,
linearize=True,
plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
**kwargs,
)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_linearize(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -2,6 +2,7 @@
# pylint: disable=redefined-outer-name
import logging
import os
import random
import subprocess # nosec B404
import tempfile
@@ -12,6 +13,7 @@ from typing import Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -199,7 +201,10 @@ def train(
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud, config=config, accelerate=True, cwd=cwd, **kwargs
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
@@ -208,7 +213,7 @@ def train(
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
@@ -220,7 +225,9 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli
@@ -381,4 +388,5 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()

View File

@@ -122,9 +122,11 @@ def load_preference_datasets(
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")

View File

@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
AxolotlCPOTrainer,
AxolotlDPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
AxolotlORPOTrainer,
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlDPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
AxolotlPRMConfig,
@@ -652,7 +653,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
@@ -965,10 +966,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
@@ -977,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -1001,11 +1004,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
@@ -1016,11 +1023,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[
"use_logits_to_keep"
] = self.cfg.dpo_use_logits_to_keep
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
max_steps = self.cfg.max_steps or total_num_steps or -1
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
max_steps=max_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
warmup_steps=self.cfg.warmup_steps,
@@ -1047,8 +1063,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
@@ -1068,7 +1088,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]

View File

@@ -5,30 +5,21 @@ module for customized trainers
from __future__ import annotations
# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOTrainer,
DPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers

View File

@@ -0,0 +1,33 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlDPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
return AxolotlDPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
return training_args_kwargs

View File

@@ -0,0 +1,15 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""

View File

@@ -0,0 +1,125 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union
import torch
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.base import (
SchedulerMixin,
_sanitize_kwargs_for_ds_tagging,
_sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
# pylint: disable=duplicate-code
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
# pylint: disable=duplicate-code
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss

View File

@@ -0,0 +1,113 @@
"""
GRPO Specific Strategy for training
"""
import importlib
import inspect
import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
LOG = logging.getLogger("axolotl")
class GRPOStrategy:
"""
Strategy for GRPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
return AxolotlGRPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
return grpo_args_kwargs
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None
@classmethod
def get_blocklist_args_kwargs(cls):
return ["dataset_num_proc"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
"""
Returns the reward function from the given fully qualified name, or the path to the reward function model.
Args:
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
or a HF hub path to the reward model.
Raises:
ValueError: If the reward function does not accept at least two arguments.
Returns:
RewardFunc: A callable that accepts prompts and completions and returns rewards,
or a path to a reward model.
"""
try:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError(
"Reward function must accept at least two arguments: prompts: list and completions: list"
)
return reward_func
except ModuleNotFoundError:
# the user has passed a string (ideally indicating the path of a reward model)
LOG.info(
f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func

View File

@@ -0,0 +1,15 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -0,0 +1,14 @@
"""
Axolotl GRPO trainer
"""
from trl import GRPOTrainer
from axolotl.core.trainers.base import SchedulerMixin
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""

View File

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,44 +0,0 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
https://github.com/HazyResearch/lolcats/
### Usage
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Step 1:
```yaml
plugins:
- axolotl.integrations.lolcats.LinearizePlugin
linearize: true
```
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
# with optional config below but this requires patching axolotl
# to allow this config to work with lora
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
Step 3: Run inference on the finetuned model
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`

View File

@@ -1,43 +0,0 @@
"""
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
Low-rank Linear Conversion via Attention Transfer
"""
import logging
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
DistillAttentionXentMSETrainer,
)
from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.lolcats")
class LinearizePlugin(BasePlugin):
"""
Plugin for lolcats integration with Axolotl.
"""
def __init__(self):
super().__init__()
# Register the Linear Llama model with transformers
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
register_linear_llama,
)
register_linear_llama()
def get_input_args(self):
return "axolotl.integrations.lolcats.LinearAttentionArgs"
def get_trainer_cls(self, cfg):
# defualt to XentMSE
# TODO: add check to allow MSE_linear
if cfg.linearize:
return DistillAttentionXentMSETrainer
return None

View File

@@ -1,47 +0,0 @@
"""
Module for handling linear attention input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class FeatureMapKwargs(BaseModel):
"""Args for feature map"""
eps: float
mlp: Optional[None] = None
fullspace: bool
class LearnedKernelKwargs(BaseModel):
"""Args for learned kernel"""
feature_dim: int
skip_connection: bool
bias: bool
zero_init: bool
class AttentionConfig(BaseModel):
"""Args for attention config"""
attention_type: str
feature_map: str
feature_map_kwargs: FeatureMapKwargs
layer_idx: Optional[None] = None
learned_kernel: str
learned_kernel_kwargs: LearnedKernelKwargs
tie_qk_kernels: bool
train_qk: bool
class LinearAttentionArgs(BaseModel):
"""
Input args for linear attention
"""
attention_config: AttentionConfig
linearize: Optional[bool] = False

View File

@@ -1,90 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear LLaMA model configuration"""
from typing import Optional
from transformers import LlamaConfig
class LinearLlamaConfig(LlamaConfig):
"""
This is the configuration class to store the configuration of a [`LinearLlamaModel`].
It is a modified LlamaConfig that includes additional parameters for linear attention.
Args:
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
Expected contents:
`attention_type` (str):
The type of attention to convert to.
`feature_map` (`str`):
The type of feature map to use for linear attention.
`feature_map_kwargs` (`dict`):
Additional arguments for the feature map.
`learned_kernel` (`str`, *optional*):
Type of learned kernel to use, if any.
`learned_kernel_kwargs` (`dict`, *optional*):
Additional arguments for the learned kernel.
`tie_qk_kernels` (`bool`, *optional*, defaults to False):
Whether to tie query and key kernels.
`rotary_config` (`dict`, *optional*):
Configuration for rotary embeddings.
`train_attention` (`bool`, *optional*, defaults to False):
Whether to train attention to match softmax attention.
`remove_base_attn` (`bool`, *optional*, defaults to True):
Whether to remove base attention after initialization.
`mask_value` (`int`, *optional*, defaults to 0):
Value to use for masking.
`eps` (`float`, *optional*, defaults to 1e-12):
Epsilon value for numerical stability.
`fp32_attention` (`bool`, *optional*, defaults to False):
Whether to use fp32 precision for attention computation.
`track_state_grads` (`bool`, *optional*, defaults to False):
Whether to track gradients of attention states.
**kwargs:
Additional arguments inherited from LlamaConfig.
"""
model_type = "linear_llama"
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
super().__init__(**kwargs)
# Set auto_map
self.auto_map = {
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
}
# Set default attention config if none provided
self.attention_config = attention_config or {"attention_type": "softmax"}
@classmethod
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
"""
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
Args:
llama_config (:class:`~transformers.LlamaConfig`):
The LlamaConfig to inherit from.
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
"""
return cls(attention_config=attention_config, **llama_config.to_dict())

View File

@@ -1,30 +0,0 @@
# Causal linear attention CUDA kernel
Usage:
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Reference: https://github.com/idiap/fast-transformers/
```bib
@inproceedings{katharopoulos_et_al_2020,
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
year = {2020}
}
@article{vyas_et_al_2020,
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
title={Fast Transformers with Clustered Attention},
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
year={2020}
}
```

View File

@@ -1,6 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
from .causal_attention import causal_dot_product

View File

@@ -1,225 +0,0 @@
//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
// Apoorv Vyas <avyas@idiap.ch>
//
#include <torch/extension.h>
/**
* Compute a*b^T and save it into out.
*
* a \in R^A
* b \in R^B
*/
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
for (int i=0; i<A; i++) {
float * bi = b;
for (int j=0; j<B; j++) {
*out += (*a) * (*bi);
out++;
bi++;
}
a++;
}
}
/**
* Implement a vector matrix product v*m and save it into out.
*
* v \in R^A
* m \in R^{AxB}
*/
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
// TODO: Consider removing the zeroing part and assuming out already
// contains 0s
for (int i=0; i<B; i++) {
out[i] = 0;
}
for (int i=0; i<A; i++) {
float *oi = out;
for (int j=0; j<B; j++) {
*oi += (*v) * (*m);
oi++;
m++;
}
v++;
}
}
/**
* Implement a vector transposed-matrix product and save it into out.
*
* v \in R^B
* m \in R^{AxB}
*/
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
for (int i=0; i<A; i++) {
float *vi = v;
float s = 0;
for (int j=0; j<B; j++) {
s += (*vi) * (*m);
vi++;
m++;
}
// TODO: Should we be aggregating? See the comment on vm_dot.
*out = s;
out++;
}
}
/**
* Compute the causally masked dot products of queries, keys and values.
*
* Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
* computation is done efficiently by changing the order of the dot products.
*/
void causal_dot_product(
const torch::Tensor queries,
const torch::Tensor keys,
const torch::Tensor values,
torch::Tensor product
) {
// Extract some shapes
int N = queries.size(0);
int H = queries.size(1);
int L = queries.size(2);
int E = queries.size(3);
int M = values.size(3);
// Create accessors for all the arguments
auto qa = queries.accessor<float, 4>();
auto ka = keys.accessor<float, 4>();
auto va = values.accessor<float, 4>();
auto pa = product.accessor<float, 4>();
#pragma omp parallel for collapse(2)
for (int n=0; n<N; n++) {
for (int h=0; h<H; h++) {
auto kv = torch::zeros({E, M}, queries.options());
float *kvp = kv.data_ptr<float>();
for (int l=0; l<L; l++) {
vvt_dot(
&ka[n][h][l][0],
&va[n][h][l][0],
kvp,
E,
M
);
vm_dot(
&qa[n][h][l][0],
kvp,
&pa[n][h][l][0],
E,
M
);
}
}
}
}
/**
* Compute the gradients of queries, keys and values given the gradient of the
* causal_dot_product output.
*
* Make sure that everything is computed in O(N D^2) complexity.
*/
void causal_dot_backward(
const torch::Tensor queries,
const torch::Tensor keys,
const torch::Tensor values,
const torch::Tensor grad_out,
torch::Tensor grad_queries,
torch::Tensor grad_keys,
torch::Tensor grad_values
) {
// Extract some shapes
int N = queries.size(0);
int H = queries.size(1);
int L = queries.size(2);
int E = queries.size(3);
int M = values.size(3);
// Create accessors for all the arguments
auto qa = queries.accessor<float, 4>();
auto ka = keys.accessor<float, 4>();
auto va = values.accessor<float, 4>();
auto ga = grad_out.accessor<float, 4>();
auto gqa = grad_queries.accessor<float, 4>();
auto gka = grad_keys.accessor<float, 4>();
auto gva = grad_values.accessor<float, 4>();
#pragma omp parallel for collapse(2)
for (int n=0; n<N; n++) {
for (int h=0; h<H; h++) {
auto kv = torch::zeros({E, M}, queries.options());
float *kvp = kv.data_ptr<float>();
// Compute the gradient wrt the queries
for (int l=0; l<L; l++) {
vvt_dot(
&ka[n][h][l][0],
&va[n][h][l][0],
kvp,
E,
M
);
vmt_dot(
&ga[n][h][l][0],
kvp,
&gqa[n][h][l][0],
E,
M
);
}
// Compute the gradient wrt the keys and values
kv.zero_();
for (int l=L-1; l>=0; l--) {
vvt_dot(
&qa[n][h][l][0],
&ga[n][h][l][0],
kvp,
E,
M
);
vmt_dot(
&va[n][h][l][0],
kvp,
&gka[n][h][l][0],
E,
M
);
vm_dot(
&ka[n][h][l][0],
kvp,
&gva[n][h][l][0],
E,
M
);
}
}
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"causal_dot_product",
&causal_dot_product,
"Compute the weighted sum of values but attending only to previous "
"values."
);
m.def(
"causal_dot_backward",
&causal_dot_backward,
"Compute the gradient of queries, keys and values given the gradient "
"of causal_dot_product."
);
}

View File

@@ -1,67 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import torch
try:
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
except ImportError as e:
print(e)
causal_dot_product_cuda = causal_dot_backward_cuda = None
class CausalDotProduct(torch.autograd.Function):
"""Compute the weighted sum of values but attending only to previous
values."""
dot = {
# "cpu": causal_dot_product_cpu,
"cuda": causal_dot_product_cuda
}
dot_backward = {
# "cpu": causal_dot_backward_cpu,
"cuda": causal_dot_backward_cuda
}
@staticmethod
def forward(ctx, Q, K, V):
# Save the inputs for the gradient computation
ctx.save_for_backward(Q, K, V)
# Create the output tensor
device = Q.device
N, H, L, _ = Q.shape
_, _, _, M = V.shape
product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
# Actually perform the dot product
CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
# breakpoint()
# CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
return product
@staticmethod
def backward(ctx, grad_out):
# Extract the saved tensors
Q, K, V = ctx.saved_tensors
# Allocate memory for the gradients
grad_Q = torch.zeros_like(Q)
grad_K = torch.zeros_like(K)
grad_V = torch.zeros_like(V)
# Actually compute the gradients
CausalDotProduct.dot_backward[Q.device.type](
Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
)
return grad_Q, grad_K, grad_V
# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply

View File

@@ -1,65 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import subprocess # nosec
import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
def get_last_arch_torch():
arch = torch.cuda.get_arch_list()[-1]
print(f"Found arch: {arch} from existing torch installation")
return arch
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True # nosec
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
arch = get_last_arch_torch()
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"] # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
setup(
name="causal_attention_cuda_cpp",
ext_modules=[
CUDAExtension(
"causal_attention_cuda",
[
# 'causal_attention.cpp',
"causal_attention_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(
["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
),
},
)
],
cmdclass={"build_ext": BuildExtension},
)

View File

@@ -1,856 +0,0 @@
"""
Linear attention classes
"""
import copy
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
# Causal linear attention dot product CUDA kernel from fast-transformers
try:
from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
fast_causal_dot_product = None
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
# -------------------
# Attention functions
# -------------------
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
Causal linear attention dot product
- If available, use CUDA kernel from fast-transformers
"""
if fast_causal_dot_product is None:
kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
return fast_causal_dot_product(q, k, v)
def linear_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
fp32_attention: bool = False,
eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Compute linear attention with CUDA kernel implementation from fast-transformers
- https://github.com/idiap/fast-transformers
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
v is shape (b, h, l, head_dim)
"""
dtype = q.dtype
# Causal mask already applied
y = causal_dot_product(
q.contiguous().to(dtype=torch.float32),
k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32),
)
if fp32_attention:
y = (
y
/ (
torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
)[..., None]
).to(dtype=dtype)
else:
y = y.to(dtype=dtype)
k = k.float().cumsum(dim=2).to(dtype=dtype)
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
return y, None, None
def softmax_attention(
q: torch.Tensor,
k: torch.Tensor,
v: Optional[torch.Tensor] = None,
causal: bool = True,
fp32_attention: bool = True,
):
"""
Standard softmax attention; only compute outputs if v is not None
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
"""
y = None
a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
n - m + 1
)
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
if fp32_attention:
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
else:
a = torch.softmax(a, dim=-1)
if v is not None:
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
return y, a, None
def quadratic_attention(
q: torch.Tensor,
k: torch.Tensor,
v: Optional[torch.Tensor] = None,
causal: bool = True,
fp32_attention: bool = False,
eps: float = 1e-12,
):
"""
Compute attention with feature maps by instantiating L x L matrix of attention weights
-> Use for attention distillation
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
"""
y = None
dtype = q.dtype
if fp32_attention:
q, k = q.float(), k.float()
a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
n - m + 1
)
a = a.masked_fill(causal_mask, 0)
# Normalize to compute attention
a = a / (a.sum(dim=-1, keepdim=True) + eps)
a = a.to(dtype=dtype) if fp32_attention else a
if torch.isnan(a).sum() > 0:
breakpoint()
if v is not None:
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
return y, a, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearAttention(nn.Module):
"""
LoLCATs attention implementation initialized from a
`LlamaAttention` or `MistralAttention` object (base_attn)
Most of the arguments are directly tied to argparse args
- For now we don't support padding.
"""
def __init__(
self,
base_attn: nn.Module, # like LlamaAttention
feature_map: str,
feature_map_kwargs: dict,
layer_idx: Optional[int] = None,
max_layer_idx: Optional[int] = None,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
tie_qk_kernels: Optional[bool] = False,
rotary_config: Optional[dict] = None,
train_attention: Optional[bool] = False,
remove_base_attn: bool = True,
attention_type: Optional[str] = "lolcats_llama",
mask_value: int = 0,
eps: float = 1e-12,
fp32_attention: bool = False,
track_state_grads: bool = False,
rank: Optional[int] = 0,
**kwargs,
) -> None:
super().__init__()
self.base_config = getattr(base_attn, "config", None)
if self.base_config is not None:
self.base_config = self.base_config.to_dict()
self.attention_type = attention_type
self.mask_value = mask_value
self.eps = eps
self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
self.max_layer_idx = max_layer_idx
self.tie_qk_kernels = tie_qk_kernels
self.train_attention = train_attention
self.base_inference = False
self.fp32_attention = fp32_attention
self.track_state_grads = track_state_grads
if rank == 0: # multi-gpu
if fp32_attention and layer_idx == 0:
print(f"-> fp32_attention is {fp32_attention}")
if layer_idx == 0 and feature_map_kwargs is not None:
for k, v in feature_map_kwargs.items():
print(f"-> {k}: {v}")
if layer_idx == 0 and learned_kernel_kwargs is not None:
for k, v in learned_kernel_kwargs.items():
print(f"-> {k}: {v}")
self.remove_base_attn = remove_base_attn
self.init_weights_(base_attn, remove_base_attn)
self.init_feature_map_(
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
)
def init_feature_map_(
self,
feature_map: str,
feature_map_kwargs: dict,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
):
"""
Initialize MLP-based feature map
"""
self.fmap_gqa = False # Turn True if specified below
if learned_kernel is not None and learned_kernel_kwargs is not None:
# Ensure dict
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
learned_kernel_kwargs["num_heads"] = self.num_heads
learned_kernel_kwargs["head_dim"] = self.head_dim
learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
learned_kernel_kwargs["device"] = self.q_proj.weight.device
# Create MLP
mlp_learned_kernel = init_learned_kernel(
learned_kernel, **learned_kernel_kwargs
)
# Add "activation"; see src.models.feature_map.py
self.feature_map_q = init_feature_map(
name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
)
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
self.feature_map_k = self.feature_map_q
else:
self.feature_map_k = copy.deepcopy(self.feature_map_q)
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
"""
Initialize module layers, weights, positional dependencies, etc.
from original softmax attention layer (base_attn)
"""
# Make other attributes accessible
self.attention_dropout = 0 # We don't use dropout
self.hidden_size = base_attn.config.hidden_size
self.num_heads = base_attn.config.num_attention_heads
self.head_dim = base_attn.head_dim
self.num_key_value_heads = base_attn.config.num_key_value_heads
self.num_key_value_groups = base_attn.num_key_value_groups
self.q_shape = [self.num_heads, self.head_dim]
self.k_shape = [self.num_key_value_heads, self.head_dim]
self.v_shape = [self.num_key_value_heads, self.head_dim]
# Copy original model projection layers
self.q_proj = base_attn.q_proj
self.k_proj = base_attn.k_proj
self.v_proj = base_attn.v_proj
self.o_proj = base_attn.o_proj
try: # If wanting to use FA2 for ground-truth inference
self._flash_attn_uses_top_left_mask = (
base_attn._flash_attn_uses_top_left_mask
)
except AttributeError:
pass
if self.remove_base_attn or remove_base_attn:
del base_attn # We don't need to keep these around
else:
self.base_attn = base_attn # For some training runs helpful to just call
def process_qkv(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[Any] = None,
):
"""
Compute queries, keys, and values
"""
b, l, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
kv_seq_len = k.shape[-2]
# Shape is (batch_size, seq_len, num_heads, head_dim)
q = q.view(b, l, *self.q_shape).transpose(1, 2)
k = k.view(b, l, *self.k_shape).transpose(1, 2)
v = v.view(b, l, *self.v_shape).transpose(1, 2)
if (
past_key_value is not None
): # and k.shape[2] > q.shape[2]: # e.g., when generating
past_key_value.window_size = getattr(
self, "decode_window_size", None
) # self.decode_window_size
if isinstance(
past_key_value, Cache
): # In Transformers v4.36+ this is a DynamicCache object
kv_seq_len += past_key_value.get_usable_length(
kv_seq_len, self.layer_idx
)
else:
kv_seq_len += past_key_value[0].shape[-2]
# Apply rotary embeddings
if position_embeddings is not None:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
return q, k, v, kv_seq_len
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[Any] = None, # "legacy" cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
- Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_embeddings, past_key_value
)
if self.base_inference:
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
y_true, _, _ = softmax_attention(q, k, v, causal=True)
y_true = (
y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
attn_weights = (None, None)
elif self.train_attention: # Distilling / learning attentions
# Note for now we assume no padding when distilling; attention masks only enforce causality
assert (
output_attentions is True
), f"When training feature maps, output_attentions should be True but is {output_attentions}"
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention (just weights)
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
attn_weights = ( # type: ignore
(attn_pred, attn_true),
(y_pred, _y_true),
) # Save both attention weights so we can supervise.
else: # Finetuning
q, k = self.feature_map_q(q), self.feature_map_k(k)
# Apply prefill mask
if attention_mask is not None and q.shape[2] > 1:
if len(attention_mask.shape) == 4:
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
..., None
] # b, 1, k_len, 1
else:
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
k = k.masked_fill(~lin_attn_mask, 0)
if past_key_value is not None: # Initialize states
if len(past_key_value.kv_states) == self.layer_idx:
b, h, _, f = k.shape
past_key_value.kv_states.append(
torch.zeros(
b, h, f, self.head_dim, dtype=q.dtype, device=q.device
)
)
past_key_value.k_states.append(
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
)
# Generating
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
assert use_cache is True
kv_state, k_state = past_key_value.update(
k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
)
if self.fp32_attention:
q = q.float()
y_true = (
torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
/ torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
..., None
]
).to(dtype=k.dtype)
else:
y_true = (
torch.einsum("bhlf,bhfd->bhld", q, kv_state)
/ torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
)
else:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
y_true, _, _ = linear_attention(
q, k, v, self.fp32_attention, self.eps
) # Ordinarily the states are ignored
past_key_value.update(
k.detach(),
v.detach(),
self.layer_idx,
accumulate_in_fp32=self.fp32_attention,
)
# doing some unnecessary recomputation here
else:
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
attn_weights = None
return y_true, attn_weights
class LinearAttentionState(Cache):
"""
Handle the KV and K states for linear attention
- Adopts HF Transformers `past_key_values` convention
- Inherits from `Cache` class
- Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self) -> None:
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states. A layer index can be optionally passed.
"""
if layer_idx is None:
raise ValueError("Layer index must not be None")
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
self._seen_tokens_by_layer.append(0)
return self._seen_tokens_by_layer[layer_idx]
def get_max_length(self) -> Optional[int]:
"""
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
"""
return None
def get_usable_length(
self, new_seq_length: int, layer_idx: Optional[int] = 0
) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = True,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.no_grad():
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
key_states, value_states = key_states.float(), value_states.float()
kv_state = torch.einsum(
"bhlf,bhld->bhfd", key_states, value_states
).detach()
k_state = key_states.sum(
dim=-2, keepdim=True
).detach() # b, h, 1, f; note the 1
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
print(
"if len(self.k_states) <= layer_idx: # Initializing kv and k states"
)
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
else:
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def to_legacy_cache(self):
"""Hack, but just return self"""
return self
def reorder_cache(self, beam_idx: torch.LongTensor):
"""
Reorders the cache for beam search, given the selected beam indices.
-> Copied from transformers/src/transformers/cache_utils.py
"""
raise NotImplementedError(
"Reordering cache not implemented for LinearAttentionState"
)
# -------------------
# feature map functions
# -------------------
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
if name == "softmax_dim" and fullspace:
return SoftmaxDim(**kwargs)
elif name == "softmax_dim" and not fullspace:
return SoftmaxDimHalfspace(**kwargs)
elif name == "exp_dim" and fullspace:
return Exp(**kwargs)
elif name == "exp_dim" and not fullspace:
return ExpHalfspace(**kwargs)
elif name == "pos_elu":
return PosELU(**kwargs)
elif name == "relu":
return ReLU(**kwargs)
else:
raise NotImplementedError
def init_learned_kernel(name: str, **kwargs):
"""
Initialize feature map MLP for linear attention
"""
if name == "untied_head_einsum":
return FeatureMapMLP(**kwargs)
elif name == "untied_head_adapter":
return FeatureMapAdapter(**kwargs)
else:
raise NotImplementedError
class FeatureMap(nn.Module):
"""
Final 'activation' of feature map. Can probably be combined with
`FeatureMapMLP` below
Full feature map is like f(xW + b)
-> This is the `f` part
"""
def __init__(
self,
activation_name: str,
head_dim_idx: int = -1,
eps: float = 1e-12,
mlp: Optional[nn.Module] = None,
fullspace: bool = True,
):
super().__init__()
self.head_dim_idx = head_dim_idx
self.eps = eps
self.mlp = mlp if mlp is not None else nn.Identity()
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
"""
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
def q_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
def k_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
"""
Base class for feature map activations
"""
def __init__(self, eps: float = 1e-12):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return x
class PosELU(FeatureMapAct):
"""
1 + ELU activation as in https://arxiv.org/abs/2006.16236
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return (1 + F.elu(x)).clamp(min=self.eps)
class ReLU(FeatureMapAct):
"""
ReLU activation as in https://arxiv.org/abs/2103.13076
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return F.relu(x).clamp(min=self.eps)
class SoftmaxDim(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.cat(
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
).clamp(min=self.eps)
class SoftmaxDimHalfspace(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.softmax(x, dim=-1).clamp(min=self.eps)
class Exp(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
x_min = torch.amin(x, dim=-1, keepdim=True)
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
min=self.eps
)
class ExpHalfspace(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
return torch.exp(x - x_max).clamp(min=self.eps)
# ----------------
# Feature map MLPs
# ----------------
class FeatureMapMLP(nn.Module):
"""
Learnable MLP in feature map.
Full feature map is like f(xW + b)
-> This is the `W` and (optional) `b` part
"""
def __init__(
self,
num_heads: int,
head_dim: int, # input dim
feature_dim: int, # output dim
dtype: torch.dtype,
device: torch.device,
skip_connection: bool = False,
bias: bool = False,
zero_init: bool = False,
normal_init: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.feature_dim = feature_dim
self.dtype = dtype
self.device = device
self.skip_connection = skip_connection
self.bias = bias
self.zero_init = zero_init
self.normal_init = normal_init
self.init_weights_()
if self.zero_init: # Zero-out weights or set as identity post-initialization
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
if self.normal_init:
with torch.no_grad():
nn.init.normal_(self.layer)
if self.skip_connection:
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
assert self.head_dim == self.feature_dim, assertion_fail
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
self.layer = nn.Parameter(
torch.zeros(
(self.num_heads, self.head_dim, self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.layer)
if self.bias:
self.bias = nn.Parameter(
torch.zeros(
(1, self.num_heads, 1, 1), # self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.bias)
else:
self.bias = 0.0 # hack
def zero_init_with_skip_(self):
"""
Initialize weights to zero matrix if skip connection
"""
with torch.no_grad():
nn.init.zeros_(self.layer)
def zero_init_(self):
"""
Initialize weights to identity matrix if no skip connection
"""
with torch.no_grad():
for i in range(self.layer.shape[0]):
try:
nn.init.eye_(self.layer[i])
except RuntimeError:
with torch.no_grad():
dtype = self.layer[i].dtype
weight = torch.eye(
*self.layer[i].shape,
requires_grad=self.layer[i].requires_grad,
device=self.layer[i].device,
)
self.layer[i] = weight.to(dtype=dtype)
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
"""
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
return x + _x if self.skip_connection else _x
class FeatureMapAdapter(FeatureMapMLP):
"""
Learnable Feature map with bottleneck adapter
as in https://arxiv.org/abs/1902.00751
We don't use but could be fun to try
"""
def __init__(self, hidden_dim: int, *args, **kwargs):
kwargs["skip_connection"] = True
kwargs["bias"] = True
kwargs["zero_init"] = True
self.hidden_dim = hidden_dim
super().__init__(*args, **kwargs)
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
kwargs = {"dtype": self.dtype, "device": self.device}
self.layer0 = nn.Parameter(
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
)
self.layer1 = nn.Parameter(
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.layer0)
nn.init.kaiming_uniform_(self.layer1)
self.bias0 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
)
self.bias1 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.bias0)
nn.init.kaiming_uniform_(self.bias1)
def zero_init_with_skip_(self):
with torch.no_grad():
nn.init.zeros_(self.layer0)
nn.init.zeros_(self.layer1)
nn.init.zeros_(self.bias0)
nn.init.zeros_(self.bias1)
def zero_init_(self):
raise NotImplementedError
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
-> Down-project, apply nonlinearity, up-project; add skip connection
"""
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
_x = F.relu(_x)
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
return x + _x if self.skip_connection else _x

View File

@@ -1,460 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
softmax_attention,
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
k_len - q_len
)
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
k_len - q_len - window_size
)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ---------------------
# Attention layer class
# ---------------------
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw'
# Determine how we compute attentions
self.quadratic_attention = hybrid_attention_quadratic
self.attention_type = kwargs[
"attention_type"
] # 'hedgehog_long_llama_window_sw'
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states must not be None")
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,685 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import logging
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
causal_dot_product,
)
LOG = logging.getLogger(__name__)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
max(k_len - q_len, 0)
)
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
max(k_len - q_len, 0) - window_size
)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ------------------------------
# Hybrid window attention linear
# ------------------------------
def under_window_linear_attention(
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_size: int,
linear_factor: torch.Tensor,
eps: float = 1e-12,
):
"""Compute hybrid window attention dot product with linear complexity in q_len"""
dtype = f_q.dtype
w = window_size
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
qkv = linear_factor * causal_dot_product(
f_q.contiguous().to(dtype=torch.float32),
f_k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32),
).to(dtype=dtype)
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
sum_qk[sum_qk == 0] += eps
return qkv, sum_qk
def sliding_window_softmax_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
window_size: int,
window_factor: torch.Tensor,
mask_value: float = -1e8,
):
"""
Compute sliding window softmax attention without materializing
O(seq_len^2) attention weights
"""
d = q.shape[-1]
# Compute windows for keys
window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
# Compute windowed_softmax(qk); causal in its construction
a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
a_sm[a_sm == 0] = -torch.finfo(
q.dtype
).max # heuristic for zeroing out padding above
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
def hybrid_attention_linear(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: Optional[torch.Tensor] = None,
linear_factor: Optional[torch.Tensor] = None,
window_size: int = 64,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Alternative hybrid attention combining sliding window and linear attentions
-> Uses O(n) memory if n is sequence length by padding and unfolding windows
"""
# window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
if window_factor is None:
raise ValueError("window_factor must be provided")
if linear_factor is None:
raise ValueError("linear_factor must be provided")
# 1. Sliding window (softmax attention)
with torch.no_grad():
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
q, k, v, window_size, window_factor, mask_value
)
# 2. Under window (linear attention)
qkv_ln, sum_qk_ln = under_window_linear_attention(
f_q, f_k, v, window_size, linear_factor, eps
)
# 3. Combine
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
return y, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
# Determine how we compute attentions
self.linear_attention = hybrid_attention_linear
self.attention_type = "lolcats_llama_window_sw"
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
if self.train_attention and self.base_inference:
with torch.no_grad():
_y_true = flash_attention_2(
self, # self.base_attn,
hidden_states=hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
)[0]
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
y_true = _y_true.reshape(b, l, -1).contiguous()
y_true = self.o_proj(y_true)
# layer_io = (hidden_states, _y_true) # hack
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
return y_true, layer_io, None
else:
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.linear_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.linear_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
_y_true = y_true.transpose(1, 2).contiguous()
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
if self.train_attention:
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states must not be None")
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)
# -----------------
# Flash Attention 2
# -----------------
def flash_attention_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Wrapper for LlamaFlashAttention2
Copied and modified from HF Transformers v4.36 and v4.43 implementations
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
try: # As in Transformers v4.36
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
except Exception: # As in Transformers v4.39
cos, sin = self.rotary_emb(key_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
LOG.debug(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self, "_flash_attention_forward", False):
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=True,
)
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0, # dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
return attn_output, past_key_value

View File

@@ -1,24 +0,0 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using standard sliding window arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from .linear_window_attention_sw import hybrid_attention_quadratic
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self, remove_base_attn=True, **kwargs):
# keep self.base_attn for Flash Attention inference
super().__init__(remove_base_attn=True, **kwargs)
self.quadratic_attention = hybrid_attention_quadratic

View File

@@ -1,466 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using the TK "terracing" arrangement
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import math
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
softmax_attention,
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
win_len = window_size
m = math.ceil(max(q_len, k_len) / window_size)
# Creates an n x n mask where n = window_size^2
mask = torch.block_diag(
*[
torch.ones(
(win_len, win_len),
)
]
* m
)
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
if mask.shape[0] > q_len:
mask = mask[-q_len:]
if mask.shape[1] > k_len:
mask = mask[:, -k_len:]
# Return softmax mask (window), linear attention mask
mask = mask[None, None, ...] # b, h, q_len, k_len
return (
torch.tril(mask).to(device=device, dtype=torch.int),
torch.tril(1 - mask).to(device=device, dtype=torch.int),
)
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ---------------------
# Attention layer class
# ---------------------
class LolcatsTKWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk'
# Determine how we compute attentions
self.quadratic_attention = hybrid_attention_quadratic
self.attention_type = kwargs[
"attention_type"
] # 'hedgehog_long_llama_window_tk'
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
self.window_factor = self.window_factors # legacy naming support
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhld,bhnd->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionTKWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states should not be None")
if layer_idx is None:
raise ValueError("layer_idx should not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,219 +0,0 @@
"""
LoLCATs + ThunderKittens linear attention + sliding window for generation
"""
import logging
from typing import Any, Callable, List, Optional
import torch
import torch.nn.functional as F
from .linear_attention import LinearAttentionState
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
LOG = logging.getLogger(__name__)
try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention
LOG.debug("Successfully imported ThunderKittens for TK window attention")
except ImportError:
LOG.debug("Failed to import ThunderKittens for TK window attention")
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
def __init__(self, *args, window_size: int = 64, **kwargs):
super().__init__(*args, **kwargs)
self.train_attention = False
self.base_inference = False
self.window_size = 64 # hard-coded support for TK kernel
self.decode_window_size = 64
b, h, l, d = 1, 32, 8192, 128
self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Any] = None, # “legacy” cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
assert (
past_key_value is not None
), "past_key_value must be provided for generation"
assert (
self.train_attention is False
), "train_attention is not supported for generation"
assert (
self.base_inference is False
), "base_inference is not supported for generation"
assert use_cache is True, "use_cache must be True for generation"
past_key_value.window_size = self.decode_window_size
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill
f_q = self.feature_map_q(q)
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k
)
k_cache, v_cache, kv_state, k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
# Softmax attention terms
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
k.shape[-1] ** -0.5
)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
..., None
]
)
self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Process prefill
# Use TK-implemented linear + terrace window attention
b, h, l, d = q.shape
device = q.device
# tk.hedgehog arguments
# y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
# kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
# k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)
betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
alphas = (
1 - betas
if self.affine_attention_factors
else torch.ones(betas.shape, dtype=torch.float32, device=device)
)
q_map = self.feature_map_q.mlp.layer
k_map = self.feature_map_k.mlp.layer
# Saves outputs to y_pred, k_state, kv_state, where we fuse:
# 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
# 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d
# 3. kv_state = torch.einsum(bhlf,bhld->bhfd,
# f_k[:, :, :-self.window_size],
# v[:, :, :-self.window_size]) # b, h, f, d
# 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d
tk_window_hedgehog_attention(
q.contiguous(),
k.contiguous(),
v.contiguous(),
self.y_true,
self.k_state,
self.kv_state,
q_map,
k_map,
alphas,
betas,
)
past_key_value.update_with_kv(
self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
)
# Concatenate heads and apply output projection
y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, None, past_key_value
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a “KV state” and “K state”
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.window_size = window_size
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
def update_with_kv(
self,
kv_state: torch.Tensor,
k_state: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_idx: int,
):
"""
Update the cache with new KV and K states
"""
if layer_idx == 0:
self._seen_tokens += k.shape[2]
self._seen_tokens_by_layer.append(k.shape[2])
# Initialize KV and K states
if len(self.decode_k_states) <= layer_idx:
self.decode_kv_states.append(kv_state)
self.decode_k_states.append(k_state)
else: # Update KV and K states
self.decode_kv_states[layer_idx] = (
self.decode_kv_states[layer_idx] + kv_state
)
self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
self.k_cache.append(k[:, :, -self.window_size :, :])
self.v_cache.append(v[:, :, -self.window_size :, :])
def update_for_decoding(
self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
):
"""
Update the cache for decoding
"""
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
k.dtype
)
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
if layer_idx == 0:
self._seen_tokens += k.shape[-2]
self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,306 +0,0 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using the TK "terracing" arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
)
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self, remove_base_attn=True, **kwargs):
# keep self.base_attn for Flash Attention inference
super().__init__(remove_base_attn=True, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
if self.train_attention and self.base_inference:
with torch.no_grad():
# LOG.debug(hidden_states.shape)
_y_true = flash_attention_2(
self, # self.base_attn,
hidden_states=hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
# output_hidden_states=False,
use_cache=False,
)[0]
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
y_true = _y_true.reshape(b, l, -1).contiguous()
y_true = self.o_proj(y_true)
layer_io = (hidden_states, _y_true) # hack
# layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
return y_true, layer_io, None
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
else:
past_key_value.window_size = self.decode_window_size
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
k.shape[-1] ** -0.5
)
# a_sm = torch.softmax(a_sm, dim=-1)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
y_pred = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
..., None
]
)
y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
if (
self.state_grad_enabled
and self.layer_idx == 0
and position_ids is not None
):
LOG.debug(
f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
)
LOG.debug(
f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
)
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
)
# Concatenate heads and apply output projection
_y_pred = y_pred.transpose(1, 2).contiguous()
y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))
if self.train_attention:
with torch.no_grad():
a_true = softmax_attention(q, k, None, causal=True)[1]
attn_weights = (_y_pred, (a_pred, a_true))
else:
attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d)
return y_pred, attn_weights, past_key_value
# -----------------
# Flash Attention 2
# -----------------
def flash_attention_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Wrapper for LlamaFlashAttention2
Copied and modified from HF Transformers v4.36 and v4.43 implementations
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
try: # As in Transformers v4.36
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
except Exception: # As in Transformers v4.39
cos, sin = self.rotary_emb(key_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
LOG.debug(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self, "_flash_attention_forward", False):
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=True,
)
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0, # dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
return attn_output, past_key_value

View File

@@ -1,361 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Linear LLaMA model implementation."""
import logging
from functools import partial
from typing import Any, Optional
from torch import nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__)
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
"""
Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention.
"""
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
# Replace the attention layer with our custom attention
self.self_attn = convert_llama_attention(
layer=self, attention_config=config.attention_config
)
class LinearLlamaModel(LlamaModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]
Args:
config: LinearLlamaConfig
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config: LinearLlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LinearLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
class LinearLlamaForCausalLM(LlamaForCausalLM):
"""
Linear LLaMA model for causal language modeling.
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config):
super().__init__(config)
self.model = LinearLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@classmethod
def from_llama(
cls,
model: LlamaForCausalLM,
config: LinearLlamaConfig,
train_attention: bool = False,
remove_base_attn: bool = True,
) -> "LinearLlamaForCausalLM":
"""
Initialize a LinearLlamaForCausalLM from a LlamaModel
"""
if config is None:
raise ValueError("Missing config")
# initialize a new model with config
new_model = cls(config=config)
# remove the default model and lm_head
del new_model.model
del new_model.lm_head
# load converted model, lm_head, and vocab_size from llama model
new_model.model = convert_attention(
model.model,
attention_config=config.attention_config,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
new_model.lm_head = model.lm_head
new_model.vocab_size = model.vocab_size
return new_model
def toggle_attention(self, train: bool = True):
"""
Toggle attention to be trainable or not
"""
toggle_attention(self.model, train=train)
def remove_base_attention(self):
"""
Remove base attention after distillation
"""
remove_base_attention(self.model)
def convert_attention(
model: nn.Module,
attention_config: dict,
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Call to convert all attention layers
"""
# Get the layers to convert if provided
softmax_attns = attention_config.get("softmax_attentions", [])
# Get the attention to convert to
attention_type = attention_config.get("attention_type")
if attention_type != "softmax":
layers = traverse_layers(model)
for layer_idx, layer in enumerate(
tqdm(layers, desc="Converting attentions...")
):
if layer_idx not in softmax_attns:
layer.self_attn = convert_llama_attention(
layer,
attention_config,
layers,
train_attention,
remove_base_attn,
)
layer.self_attn.converted = True
else:
# Freeze any preserved softmax attention layers
for p in layer.parameters():
p.requires_grad = False
else:
LOG.info(
f"-> attention_config.attention_type is {attention_type}; not converting attentions"
)
return model
def toggle_attention(llama_model: nn.Module, train: bool = False):
"""
Make attentions trainable if train is True
-> Set train_attention = False when finetuning
"""
for layer in traverse_layers(llama_model):
layer.self_attn.train_attention = train
return llama_model
def remove_base_attention(llama_model: nn.Module):
"""
Remove teacher attention after distillation (if we keep it)
"""
for layer in traverse_layers(llama_model):
if getattr(layer.self_attn, "base_attn", False):
del layer.self_attn.base_attn
return llama_model
def traverse_layers(model: nn.Module, verbose: bool = False):
"""
Return list of model layers
"""
try:
layers = model.model.layers
if verbose:
LOG.info("-> Loading from model.model.layers")
except AttributeError as e: # if base model
if verbose:
LOG.info(e)
try:
layers = model.layers
if verbose:
LOG.info("-> Loading from model.layers")
except AttributeError as e1: # If we make a PEFT model
if verbose:
LOG.info(e1)
layers = model.base_model.model.model.layers
if verbose:
LOG.info("-> Loading from model.base_model.model.model.layers")
return layers
def convert_llama_attention(
layer: nn.Module,
attention_config: dict,
layers: Optional[list[nn.Module]] = None, # list of layers
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Converts a single layer's attention layer as specified by attention_config
"""
return get_attention(**attention_config)(
base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1 if layers else None,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
def get_attention(attention_type: str, **kwargs):
"""
Get the linear attention class; either purely linear or linear with sliding window
-> 'linear' == 'lolcats_llama'
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
"""
kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama":
from .linear_attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk":
from .linear_window_attention_tk import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw":
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear":
from .linear_window_attention_sw_linear import (
LolcatsLinearSlidingWindowAttention,
)
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk":
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw":
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs)
else:
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
return None
def get_attention_cache(attention_type: str, past_key_values: Any = None):
"""
Determine how we store past keys and values when generating
"""
if attention_type is None:
return past_key_values
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type:
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type:
from .linear_window_attention_tk import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type:
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type:
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "softmax" in attention_type:
return past_key_values
else:
from .linear_attention import LinearAttentionState
return LinearAttentionState()
def register_linear_llama():
"""
Register Linear LLaMA model with the Transformers library.
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# registering for auto classes to save files
LinearLlamaConfig.register_for_auto_class("AutoConfig")
LinearLlamaModel.register_for_auto_class("AutoModel")
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")

View File

@@ -1,118 +0,0 @@
"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
"""
from typing import Any
from torch import Tensor, nn, tensor
from axolotl.core.trainers.base import AxolotlTrainer
class DistillAttentionXentMSETrainer(AxolotlTrainer):
"""
Custom trainer class for distilling attentions.
- We compute and store the attention outputs and/or weights for each head and layer,
for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
- We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
"""
def __init__(
self,
model: nn.Module,
mse_factor: float = 1e3,
xent_factor: float = 0,
**kwargs: Any,
):
super().__init__(model=model, **kwargs)
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
self.criterion_mse = nn.MSELoss(reduction="mean")
self.mse_factor = mse_factor
self.xent_factor = xent_factor
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
self.model_accepts_loss_kwargs = False # added to combat explosive loss
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Tensor],
return_outputs=False,
num_items_in_batch=None,
) -> tuple[Tensor, dict]:
"""
Attention distillation ("attention transfer")
- For each layer and head, get attentions and train to
minimize some combo of MSE and cross-entropy loss
"""
# alias inputs to data
data = inputs
device = model.device
# Filter out labels
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
# set num_items_in_batch
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
# Forward pass
outputs = model(**inputs, output_attentions=True, use_cache=False)
outputs = outputs.get("attentions")
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
# n_layers x (predicted_attns, true_attns)
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
loss_mse = tensor(0.0, device=device)
loss_xent = tensor(0.0, device=device)
n_layers = 0 # Number of layers to distill
softmax_layers = []
for layer_idx, attns in enumerate(outputs):
if attns is not None:
if len(attns) != 2:
attns = attns.cpu()
else:
if self.xent_factor > 0:
# Cross-entropy loss
a_pred, a_true = attns[0]
a_pred = a_pred.clamp(
min=1e-12
).log() # nn.CrossEntropy assumes unnormalized logits
k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len
# Compute mean cross-entropy over all queries
a_pred = a_pred.contiguous().view(-1, k_len)
a_true = a_true.contiguous().view(-1, k_len)
loss_xent += self.criterion_xent(a_pred, a_true)
if self.mse_factor > 0:
loss_mse += self.criterion_mse(*attns[1])
n_layers += 1
else:
softmax_layers.append(layer_idx)
if n_layers > 0:
loss_xent = loss_xent / n_layers * self.xent_factor
loss_mse = loss_mse / n_layers * self.mse_factor
loss = loss_xent + loss_mse
if "position_ids" in data:
outputs = {
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
"input_len": data["position_ids"].shape[1],
"position_ids": data["position_ids"][0].detach().cpu().numpy(),
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
else:
outputs = {
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
return (loss, outputs) if return_outputs else loss

View File

@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)
if len(strategy.split(".")) > 1:
try:
importlib.import_module(
strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]),
)
module_base = ".".join(strategy.split(".")[:-2])
strategy = strategy.split(".")[-2]
except ModuleNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1])
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught

View File

@@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

View File

@@ -0,0 +1,14 @@
"""
DPO prompt strategies passthrough/zero-processing strategy
"""
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(
sample, tokenizer=None
): # pylint: disable=possibly-unused-variable,unused-argument
return sample
return transform_fn

View File

@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
from .trl import TrlConfig
LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -33,6 +35,7 @@ class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
@@ -663,14 +666,20 @@ class AxolotlInputConfig(
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: Optional[bool] = None
rl: Optional[RLType] = None
trl: Optional[TrlConfig] = Field(
default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore

View File

@@ -0,0 +1,32 @@
"""
GRPO specific configuration args
"""
from typing import List, Optional
from pydantic import BaseModel, Field
class TrlConfig(BaseModel):
"""
Input args for TRL.
"""
beta: Optional[float] = None
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
},
)
# GRPO specific args
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64

View File

@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
dataset.save_to_disk(str(prepared_ds_path))
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
**map_kwargs,
)
return data_set
@@ -150,36 +151,45 @@ def load_prepare_preference_datasets(cfg):
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
if not cfg.skip_prepare_dataset:
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)

View File

@@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
md5,
retry_on_request_exceptions,
)
@@ -56,7 +57,7 @@ from axolotl.utils.trainer import (
process_datasets_for_packing,
)
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5)
@@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if cfg.sample_packing and not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if not cfg.skip_prepare_dataset:
dataset = drop_long_seq_in_dataset(dataset, cfg)
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")

View File

@@ -1,4 +1,5 @@
"""data handling helpers"""
import functools
import hashlib
import logging
@@ -6,10 +7,15 @@ import time
from enum import Enum
import huggingface_hub
import numpy as np
import requests
from datasets import Dataset
from datasets import Dataset, IterableDataset
LOG = logging.getLogger("axolotl")
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
LOG = logging.getLogger(__name__)
class RetryStrategy(Enum):
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
)
return train_dataset, eval_dataset, dataset
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
)
try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass
try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset

View File

@@ -1053,9 +1053,12 @@ class ModelLoader:
if self.cfg.resize_token_embeddings_to_32x
else len(self.tokenizer)
)
if (
hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings != embeddings_len
if hasattr(self.model, "get_input_embeddings") and (
self.model.get_input_embeddings().num_embeddings < embeddings_len
or (
self.model.get_input_embeddings().num_embeddings > embeddings_len
and self.cfg.shrink_embeddings
)
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
else:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
return lengths

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions"""
import json
import math
import os
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
Works for both single-example (list[int]) or batched (list[list[int]]).
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
try:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
pass
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_long,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
def drop_no_trainable_tokens(sample):
"""
Drop samples if all labels are -100 (i.e., zero trainable tokens).
@@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
@@ -622,7 +576,7 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"):
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"num_labels": 1,
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"sequence_len": 2048,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,