Compare commits
25 Commits
quantize-p
...
feat/linea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13d458d0ae | ||
|
|
ebd406af1d | ||
|
|
caa49a9d7d | ||
|
|
c15ea6b956 | ||
|
|
578fa764c8 | ||
|
|
0e6efaa10c | ||
|
|
c4cb622590 | ||
|
|
0f82bd2d18 | ||
|
|
49746b184f | ||
|
|
9e1c4de13c | ||
|
|
2d5f692fc0 | ||
|
|
2fd5c45c2e | ||
|
|
8294e6218f | ||
|
|
253dcdd0cf | ||
|
|
4cc60df876 | ||
|
|
2bc7833a4e | ||
|
|
1fb8d86396 | ||
|
|
adeefc1991 | ||
|
|
fb88269dcb | ||
|
|
433cf4a8c7 | ||
|
|
0b7b58c8be | ||
|
|
81731adc1d | ||
|
|
a1715aa317 | ||
|
|
ce0cd470f7 | ||
|
|
311d6eb5da |
135
src/axolotl/cli/convert_linear_attention.py
Normal file
135
src/axolotl/cli/convert_linear_attention.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
|
from axolotl.cli.config import load_cfg
|
||||||
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
|
||||||
|
LinearLlamaConfig,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||||
|
LinearLlamaForCausalLM,
|
||||||
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model_config
|
||||||
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    """
    Convert attention to linear attention and perform attention transfer via distillation.

    Loads the base model, swaps its attention for LoLCATs linear attention,
    trains only the new attention parameters to match the frozen softmax
    attention, then saves the distilled model to `{cfg.output_dir}/distilled`.

    Args:
        cfg: Parsed axolotl configuration. Quantization and peft settings are
            forcibly disabled below before the model is loaded.
        cli_args: Trainer CLI arguments, forwarded to dataset loading.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # ensure quantization and peft are turned off (due to how we need to re-apply peft later)
    cfg.load_in_8bit = False
    cfg.load_in_4bit = False
    cfg.adapter = None

    # load model
    model, tokenizer = load_model_and_tokenizer(cfg=cfg)

    # freeze model — only the linear-attention parameters added below are trained
    for p in model.parameters():
        p.requires_grad = False

    # convert to linear llama
    linear_llama_config = LinearLlamaConfig.from_llama(
        model.config, cfg.attention_config
    )
    model = LinearLlamaForCausalLM.from_llama(
        model, config=linear_llama_config, train_attention=True
    )

    # set save_path, save tokenizer and model config.
    # Saved up-front so the output directory is usable even if training is interrupted.
    save_path = str(os.path.join(cfg.output_dir, "distilled"))
    tokenizer.save_pretrained(save_path)
    if hasattr(model, "config"):
        model.config.save_pretrained(save_path)

    # Get datasets
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # toggle attention to be trainable
    model.toggle_attention(train=True)

    # Setup trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),
        tokenizer=tokenizer,
        processor=None,
        total_num_steps=total_num_steps,
    )

    # train
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    # drop base_attention + remove training attn
    model.toggle_attention(train=False)
    model.remove_base_attention()

    # NOTE: If in peft mode, consider whether to auto-merge

    # save model
    safe_serialization = cfg.save_safetensors is True
    # NOTE: may need to consider other ways of saving due to multi-gpu etc
    model.save_pretrained(save_path, safe_serialization=safe_serialization)

    # cleanup — release references before asking plugins to free resources
    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer

    plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_linearize`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # load cfg, force linearize and add plugin to linearize
    parsed_cfg = load_cfg(
        config,
        linearize=True,
        plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
        **kwargs,
    )

    # Parse trainer flags from sys.argv; remaining (unrecognized) args are ignored.
    parser = HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    do_linearize(parsed_cfg, parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Load environment variables (e.g. HF tokens) from a .env file before
    # dispatching to the CLI entry point.
    load_dotenv()
    fire.Fire(do_cli)
|
||||||
201
src/axolotl/integrations/lolcats/LICENSE
Normal file
201
src/axolotl/integrations/lolcats/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
44
src/axolotl/integrations/lolcats/README.md
Normal file
44
src/axolotl/integrations/lolcats/README.md
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
|
||||||
|
|
||||||
|
https://github.com/HazyResearch/lolcats/
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
||||||
|
|
||||||
|
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
||||||
|
# nano setup.py
|
||||||
|
|
||||||
|
# Build the CUDA kernel
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 1:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.lolcats.LinearizePlugin
|
||||||
|
|
||||||
|
linearize: true
|
||||||
|
```
|
||||||
|
|
||||||
|
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` (TODO: expose this as a first-class `axolotl` CLI subcommand)
|
||||||
|
|
||||||
|
Step 2: Remove the `linearize: true` entry from the config and finetune with LoRA using the target modules below.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||||
|
|
||||||
|
# with optional config below but this requires patching axolotl
|
||||||
|
# to allow this config to work with lora
|
||||||
|
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
|
||||||
|
```
|
||||||
|
|
||||||
|
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
|
||||||
|
|
||||||
|
Step 3: Run inference on the finetuned model
|
||||||
|
|
||||||
|
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`
|
||||||
43
src/axolotl/integrations/lolcats/__init__.py
Normal file
43
src/axolotl/integrations/lolcats/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""
|
||||||
|
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
|
||||||
|
|
||||||
|
Low-rank Linear Conversion via Attention Transfer
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
|
||||||
|
DistillAttentionXentMSETrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.lolcats")
|
||||||
|
|
||||||
|
|
||||||
|
class LinearizePlugin(BasePlugin):
    """
    Plugin for lolcats integration with Axolotl.
    """

    def __init__(self):
        super().__init__()

        # Register the Linear Llama model with transformers
        # (imported lazily so the plugin module can be imported without
        # pulling in the modeling code).
        from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
            register_linear_llama,
        )

        register_linear_llama()

    def get_input_args(self):
        """Return the dotted path of the pydantic args model for this plugin."""
        return "axolotl.integrations.lolcats.LinearAttentionArgs"

    def get_trainer_cls(self, cfg):
        """Return the distillation trainer class when linearization is enabled."""
        # default to XentMSE
        # TODO: add check to allow MSE_linear
        if cfg.linearize:
            return DistillAttentionXentMSETrainer

        return None
|
||||||
47
src/axolotl/integrations/lolcats/args.py
Normal file
47
src/axolotl/integrations/lolcats/args.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
Module for handling linear attention input arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapKwargs(BaseModel):
    """Args for feature map"""

    # Numerical-stability epsilon used by the feature map.
    eps: float
    # NOTE(review): Optional[None] only admits the value None — presumably a
    # placeholder for a future MLP sub-config; confirm the intended type.
    mlp: Optional[None] = None
    fullspace: bool
|
||||||
|
|
||||||
|
|
||||||
|
class LearnedKernelKwargs(BaseModel):
    """Args for learned kernel"""

    # Dimension of the learned feature map.
    feature_dim: int
    # Whether to add a skip connection around the kernel.
    skip_connection: bool
    # Whether the kernel layer uses a bias term.
    bias: bool
    # Whether to zero-initialize the kernel weights.
    zero_init: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionConfig(BaseModel):
    """Args for attention config"""

    # Type of attention to convert to (e.g. the linear-attention variant).
    attention_type: str
    # Name of the feature map applied to queries/keys.
    feature_map: str
    feature_map_kwargs: FeatureMapKwargs
    # NOTE(review): Optional[None] only admits None — layer index appears to be
    # filled in per layer at conversion time; confirm intended type.
    layer_idx: Optional[None] = None
    learned_kernel: str
    learned_kernel_kwargs: LearnedKernelKwargs
    # Whether query and key kernels share weights.
    tie_qk_kernels: bool
    # Whether the q/k projections themselves are trained.
    train_qk: bool
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionArgs(BaseModel):
    """
    Input args for linear attention
    """

    attention_config: AttentionConfig

    # When True, the linearization/distillation flow is enabled.
    linearize: Optional[bool] = False
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Linear LLaMA model configuration"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
|
||||||
|
class LinearLlamaConfig(LlamaConfig):
    """
    This is the configuration class to store the configuration of a [`LinearLlamaModel`].
    It is a modified LlamaConfig that includes additional parameters for linear attention.

    Args:
        attention_config (`dict`):
            Dictionary containing the configuration for linear attention mechanism.
            Expected contents:
            `attention_type` (str):
                The type of attention to convert to.
            `feature_map` (`str`):
                The type of feature map to use for linear attention.
            `feature_map_kwargs` (`dict`):
                Additional arguments for the feature map.
            `learned_kernel` (`str`, *optional*):
                Type of learned kernel to use, if any.
            `learned_kernel_kwargs` (`dict`, *optional*):
                Additional arguments for the learned kernel.
            `tie_qk_kernels` (`bool`, *optional*, defaults to False):
                Whether to tie query and key kernels.
            `rotary_config` (`dict`, *optional*):
                Configuration for rotary embeddings.
            `train_attention` (`bool`, *optional*, defaults to False):
                Whether to train attention to match softmax attention.
            `remove_base_attn` (`bool`, *optional*, defaults to True):
                Whether to remove base attention after initialization.
            `mask_value` (`int`, *optional*, defaults to 0):
                Value to use for masking.
            `eps` (`float`, *optional*, defaults to 1e-12):
                Epsilon value for numerical stability.
            `fp32_attention` (`bool`, *optional*, defaults to False):
                Whether to use fp32 precision for attention computation.
            `track_state_grads` (`bool`, *optional*, defaults to False):
                Whether to track gradients of attention states.

        **kwargs:
            Additional arguments inherited from LlamaConfig.
    """

    model_type = "linear_llama"

    def __init__(self, attention_config: Optional[dict] = None, **kwargs):
        super().__init__(**kwargs)

        # Set auto_map so trust_remote_code loading resolves the custom
        # config/model classes from the saved checkpoint directory.
        self.auto_map = {
            "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
            "AutoModel": "modeling_linear_llama.LinearLlamaModel",
            "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
        }

        # Set default attention config if none provided
        self.attention_config = attention_config or {"attention_type": "softmax"}

    @classmethod
    def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
        """
        Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.

        Args:
            llama_config (:class:`~transformers.LlamaConfig`):
                The LlamaConfig to inherit from.

            attention_config (`dict`):
                Dictionary containing the configuration for linear attention mechanism.
        """

        return cls(attention_config=attention_config, **llama_config.to_dict())
|
||||||
30
src/axolotl/integrations/lolcats/linear_llama/csrc/README.md
Normal file
30
src/axolotl/integrations/lolcats/linear_llama/csrc/README.md
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Causal linear attention CUDA kernel
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```bash
|
||||||
|
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
||||||
|
|
||||||
|
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
||||||
|
# nano setup.py
|
||||||
|
|
||||||
|
# Build the CUDA kernel
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
Reference: https://github.com/idiap/fast-transformers/
|
||||||
|
|
||||||
|
```bib
|
||||||
|
@inproceedings{katharopoulos_et_al_2020,
|
||||||
|
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
|
||||||
|
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
|
||||||
|
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
|
||||||
|
year = {2020}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{vyas_et_al_2020,
|
||||||
|
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
|
||||||
|
title={Fast Transformers with Clustered Attention},
|
||||||
|
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
|
||||||
|
year={2020}
|
||||||
|
}
|
||||||
|
```
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||||
|
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||||
|
# Apoorv Vyas <avyas@idiap.ch>
|
||||||
|
#
|
||||||
|
from .causal_attention import causal_dot_product
|
||||||
@@ -0,0 +1,225 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||||
|
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||||
|
// Apoorv Vyas <avyas@idiap.ch>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Accumulate the outer product a*b^T into out (out is NOT zeroed first).
 *
 * a \in R^A
 * b \in R^B
 * out \in R^{AxB}, row-major
 */
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
    for (int i = 0; i < A; i++) {
        const float ai = a[i];
        float *row = out + i * B;
        for (int j = 0; j < B; j++) {
            row[j] += ai * b[j];
        }
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Compute the vector-matrix product v*m and store it into out.
 *
 * v \in R^A
 * m \in R^{AxB}, row-major
 * out \in R^B (overwritten)
 */
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
    // Zero the output first; the accumulation below assumes a clean slate.
    // TODO: Consider removing the zeroing part and assuming out already
    // contains 0s
    for (int j = 0; j < B; j++) {
        out[j] = 0;
    }

    for (int i = 0; i < A; i++) {
        const float vi = v[i];
        const float *row = m + i * B;
        for (int j = 0; j < B; j++) {
            out[j] += vi * row[j];
        }
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Compute the vector transposed-matrix product v*m^T and store it into out.
 *
 * v \in R^B
 * m \in R^{AxB}, row-major
 * out \in R^A (overwritten, not accumulated — see the comment on vm_dot)
 */
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
    for (int i = 0; i < A; i++) {
        const float *row = m + i * B;
        float acc = 0;
        for (int j = 0; j < B; j++) {
            acc += v[j] * row[j];
        }
        // TODO: Should we be aggregating? See the comment on vm_dot.
        out[i] = acc;
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Compute the causally masked dot products of queries, keys and values.
 *
 * Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
 * computation is done efficiently by changing the order of the dot products:
 * a running (E x M) state kv = sum_l K_l V_l^T is maintained, so each
 * position costs O(E*M) instead of O(L*E) — linear in sequence length.
 *
 * Tensors are float32 with layout (N, H, L, E) for queries/keys and
 * (N, H, L, M) for values; `product` is the (N, H, L, M) output.
 */
void causal_dot_product(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    torch::Tensor product
) {
    // Extract some shapes
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto pa = product.accessor<float, 4>();

    // Each (batch, head) pair is independent, so parallelize over both.
    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            // Running key-value state for this (n, h); zeroed per head.
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();
            for (int l=0; l<L; l++) {
                // kv += K_l * V_l^T (vvt_dot accumulates, preserving history)
                vvt_dot(
                    &ka[n][h][l][0],
                    &va[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                // product[n][h][l] = Q_l * kv
                vm_dot(
                    &qa[n][h][l][0],
                    kvp,
                    &pa[n][h][l][0],
                    E,
                    M
                );
            }
        }
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Compute the gradients of queries, keys and values given the gradient of the
 * causal_dot_product output.
 *
 * Make sure that everything is computed in O(N D^2) complexity.
 *
 * Shapes mirror causal_dot_product: queries/keys (N, H, L, E), values and
 * grad_out (N, H, L, M). The three grad_* output tensors are written in place.
 */
void causal_dot_backward(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    const torch::Tensor grad_out,
    torch::Tensor grad_queries,
    torch::Tensor grad_keys,
    torch::Tensor grad_values
) {
    // Extract some shapes
    int N = queries.size(0);  // batch
    int H = queries.size(1);  // heads
    int L = queries.size(2);  // sequence length
    int E = queries.size(3);  // query/key feature dim
    int M = values.size(3);   // value feature dim

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto ga = grad_out.accessor<float, 4>();
    auto gqa = grad_queries.accessor<float, 4>();
    auto gka = grad_keys.accessor<float, 4>();
    auto gva = grad_values.accessor<float, 4>();

    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            // Scratch (E, M) state, reused by both sweeps below.
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();

            // Compute the gradient wrt the queries
            // Forward sweep: kv = sum_{s<=l} k_s v_s^T, then
            // grad_q_l[e] = sum_m kv[e][m] * grad_out_l[m].
            for (int l=0; l<L; l++) {
                vvt_dot(
                    &ka[n][h][l][0],
                    &va[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                vmt_dot(
                    &ga[n][h][l][0],
                    kvp,
                    &gqa[n][h][l][0],
                    E,
                    M
                );
            }

            // Compute the gradient wrt the keys and values
            // Reverse sweep: position l influences every output at s >= l, so
            // now kv = sum_{s>=l} q_s grad_out_s^T.
            kv.zero_();
            for (int l=L-1; l>=0; l--) {
                vvt_dot(
                    &qa[n][h][l][0],
                    &ga[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                // grad_k_l = kv * v_l
                vmt_dot(
                    &va[n][h][l][0],
                    kvp,
                    &gka[n][h][l][0],
                    E,
                    M
                );
                // grad_v_l = k_l * kv
                vm_dot(
                    &ka[n][h][l][0],
                    kvp,
                    &gva[n][h][l][0],
                    E,
                    M
                );
            }
        }
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
// Expose the forward/backward kernels to Python as a torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "causal_dot_product",
        &causal_dot_product,
        "Compute the weighted sum of values but attending only to previous "
        "values."
    );
    m.def(
        "causal_dot_backward",
        &causal_dot_backward,
        "Compute the gradient of queries, keys and values given the gradient "
        "of causal_dot_product."
    );
}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||||
|
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||||
|
# Apoorv Vyas <avyas@idiap.ch>
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Try to load the compiled CUDA extension (built from csrc/setup.py); fall
# back to None so callers can detect that the fast kernels are unavailable.
try:
    from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
    from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
except ImportError as e:
    print(e)  # surface the reason the extension failed to import
    causal_dot_product_cuda = causal_dot_backward_cuda = None
|
||||||
|
|
||||||
|
|
||||||
|
class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""

    # Device-type -> kernel dispatch tables. The "cuda" entries are None when
    # the compiled extension failed to import.
    dot = {
        "cuda": causal_dot_product_cuda,
    }
    dot_backward = {
        "cuda": causal_dot_backward_cuda,
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        """Run the causal dot-product kernel; stash Q, K, V for backward."""
        ctx.save_for_backward(Q, K, V)

        # Output keeps the queries' (N, H, L) layout and the values' last dim.
        N, H, L, _ = Q.shape
        M = V.shape[-1]
        product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=Q.device)

        # The kernel writes its result into `product` in place.
        CausalDotProduct.dot[Q.device.type](Q.data, K.data, V.data, product)

        return product

    @staticmethod
    def backward(ctx, grad_out):
        """Compute gradients for Q, K and V with the backward kernel."""
        Q, K, V = ctx.saved_tensors

        # The kernel fills these in place.
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        CausalDotProduct.dot_backward[Q.device.type](
            Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
        )

        return grad_Q, grad_K, grad_V
|
||||||
|
|
||||||
|
|
||||||
|
# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
65
src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py
Normal file
65
src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||||
|
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||||
|
# Apoorv Vyas <avyas@idiap.ch>
|
||||||
|
#
|
||||||
|
|
||||||
|
import subprocess # nosec
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from setuptools import setup
|
||||||
|
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_arch_torch():
    """Return the newest CUDA arch tag reported by the installed torch build."""
    archs = torch.cuda.get_arch_list()
    arch = archs[-1]
    print(f"Found arch: {arch} from existing torch installation")
    return arch
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_bare_metal_version(cuda_dir):
    """Parse `nvcc -V` from *cuda_dir* and return (raw_output, major, minor).

    major/minor are strings; minor is the first character after the dot,
    which also strips the trailing comma nvcc prints (e.g. "11.8,").
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True  # nosec
    )
    tokens = raw_output.split()
    version = tokens[tokens.index("release") + 1]
    parts = version.split(".")
    return raw_output, parts[0], parts[1][0]
|
||||||
|
|
||||||
|
|
||||||
|
def append_nvcc_threads(nvcc_extra_args, version=None):
    """Append nvcc's multi-threaded compilation flag when supported.

    nvcc supports ``--threads`` from CUDA 11.2 onward.

    Args:
        nvcc_extra_args: existing list of nvcc flags (not mutated).
        version: optional ``(major, minor)`` pair (ints or digit strings)
            overriding autodetection; mainly useful for testing. When None,
            the version is read from the nvcc under ``CUDA_HOME``.

    Returns:
        A new flag list with ``["--threads", "4"]`` appended when supported.
    """
    if version is None:
        _, major, minor = get_cuda_bare_metal_version(CUDA_HOME)
    else:
        major, minor = version
    # Compare as a (major, minor) tuple: the previous check
    # `major >= 11 and minor >= 2` wrongly rejected e.g. CUDA 12.0 / 12.1.
    if (int(major), int(minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args
|
||||||
|
|
||||||
|
|
||||||
|
# Detect the newest arch the installed torch supports.
arch = get_last_arch_torch()
# NOTE(review): sm_num is computed but never used below, and cc_flag is
# hard-coded to compute_90 regardless of the detected arch — confirm whether
# cc_flag should be derived from sm_num instead.
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"]  # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']

# Build the fused causal-attention CUDA kernel as a torch extension.
setup(
    name="causal_attention_cuda_cpp",
    ext_modules=[
        CUDAExtension(
            "causal_attention_cuda",
            [
                # 'causal_attention.cpp',
                "causal_attention_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(
                    ["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
                ),
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
|
||||||
@@ -0,0 +1,856 @@
|
|||||||
|
"""
|
||||||
|
Linear attention classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
# Causal linear attention dot product CUDA kernel from fast-transformers.
# Falls back to None so causal_dot_product() below can use a pure-torch path.
try:
    from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
    fast_causal_dot_product = None
|
||||||
|
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# Attention functions
|
||||||
|
# -------------------
|
||||||
|
|
||||||
|
|
||||||
|
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Causal linear-attention numerator: out_l = q_l . sum_{s<=l} k_s v_s^T.

    Dispatches to the fast-transformers CUDA kernel when it was importable;
    otherwise falls back to a pure-torch cumulative-sum formulation (which
    materializes a (b, h, l, f, d) tensor, so it is memory hungry).
    """
    if fast_causal_dot_product is not None:
        return fast_causal_dot_product(q, k, v)
    outer = torch.einsum("bhlf,bhld->bhlfd", k, v)
    return torch.einsum("bhlf,bhlfd->bhld", q, outer.cumsum(dim=2))
|
||||||
|
|
||||||
|
|
||||||
|
def linear_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    fp32_attention: bool = False,
    eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Compute linear attention with CUDA kernel implementation from fast-transformers
    - https://github.com/idiap/fast-transformers
    - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
      v is shape (b, h, l, head_dim)

    Returns (y, None, None); the trailing Nones keep the signature parallel
    to softmax_attention / quadratic_attention.
    """
    out_dtype = q.dtype
    # Causal mask already applied inside the kernel; numerator is always
    # computed in fp32.
    num = causal_dot_product(
        q.contiguous().to(dtype=torch.float32),
        k.contiguous().to(dtype=torch.float32),
        v.contiguous().to(dtype=torch.float32),
    )
    if fp32_attention:
        # Normalize in fp32, cast once at the end.
        denom = (
            torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
        )
        y = (num / denom[..., None]).to(dtype=out_dtype)
    else:
        # Cast the numerator back first, then normalize in the input dtype.
        num = num.to(dtype=out_dtype)
        k_cum = k.float().cumsum(dim=2).to(dtype=out_dtype)
        y = num / (torch.einsum("bhld,bhld->bhl", q, k_cum) + eps)[..., None]
    return y, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: Optional[torch.Tensor] = None,
    causal: bool = True,
    fp32_attention: bool = True,
):
    """Standard scaled-dot-product softmax attention.

    Assumes q, k, v are (batch_size, num_heads, seq_len, head_dim).
    Returns (y, attention_weights, None); y is None when v is None
    (weights-only mode, used for distillation supervision).
    """
    scale = k.shape[-1] ** -0.5
    scores = torch.einsum("bhmd,bhnd->bhmn", q, k) * scale
    if causal:  # mask out positions after the query index
        m, n = scores.shape[-2:]
        future = torch.ones((m, n), device=scores.device, dtype=torch.bool).triu(
            n - m + 1
        )
        scores = scores.masked_fill(future, -torch.finfo(scores.dtype).max)
    if fp32_attention:
        weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        weights = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhmn,bhnd->bhmd", weights, v) if v is not None else None
    return out, weights, None
|
||||||
|
|
||||||
|
|
||||||
|
def quadratic_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: Optional[torch.Tensor] = None,
    causal: bool = True,
    fp32_attention: bool = False,
    eps: float = 1e-12,
):
    """
    Compute attention with feature maps by instantiating L x L matrix of attention weights
    -> Use for attention distillation
    -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)

    Returns:
        (y, a, None) where `a` are the (sum-normalized, unscaled) attention
        weights and `y` is the weighted value sum (None when v is None).

    Raises:
        ValueError: if the normalized weights contain NaNs.
    """
    y = None
    dtype = q.dtype
    if fp32_attention:
        q, k = q.float(), k.float()
    a = torch.einsum("bhmd,bhnd->bhmn", q, k)  # note we don't scale, tho we could
    if causal:  # Apply causal mask
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
            n - m + 1
        )
        a = a.masked_fill(causal_mask, 0)
    # Normalize to compute attention (sum-normalization, not softmax)
    a = a / (a.sum(dim=-1, keepdim=True) + eps)
    a = a.to(dtype=dtype) if fp32_attention else a
    if torch.isnan(a).sum() > 0:
        # Was `breakpoint()`: dropping into the debugger hangs non-interactive
        # (batch / distributed) runs. Fail loudly instead.
        raise ValueError("quadratic_attention produced NaN attention weights")
    if v is not None:
        y = torch.einsum("bhmn,bhnd->bhmd", a, v)
    return y, a, None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsLinearAttention(nn.Module):
    """
    LoLCATs attention implementation initialized from a
    `LlamaAttention` or `MistralAttention` object (base_attn)

    Most of the arguments are directly tied to argparse args
    - For now we don't support padding.
    """

    def __init__(
        self,
        base_attn: nn.Module,  # like LlamaAttention
        feature_map: str,
        feature_map_kwargs: dict,
        layer_idx: Optional[int] = None,
        max_layer_idx: Optional[int] = None,
        learned_kernel: Optional[str] = None,
        learned_kernel_kwargs: Optional[dict] = None,
        tie_qk_kernels: Optional[bool] = False,
        rotary_config: Optional[dict] = None,
        train_attention: Optional[bool] = False,
        remove_base_attn: bool = True,
        attention_type: Optional[str] = "lolcats_llama",
        mask_value: int = 0,
        eps: float = 1e-12,
        fp32_attention: bool = False,
        track_state_grads: bool = False,
        rank: Optional[int] = 0,
        **kwargs,
    ) -> None:
        super().__init__()
        # Keep a plain-dict copy of the base attention's HF config (if any).
        self.base_config = getattr(base_attn, "config", None)
        if self.base_config is not None:
            self.base_config = self.base_config.to_dict()
        self.attention_type = attention_type
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
        self.max_layer_idx = max_layer_idx
        self.tie_qk_kernels = tie_qk_kernels
        self.train_attention = train_attention
        # When True, forward() runs the base softmax attention only.
        self.base_inference = False
        self.fp32_attention = fp32_attention
        self.track_state_grads = track_state_grads
        if rank == 0:  # multi-gpu: only print config once, on layer 0 of rank 0
            if fp32_attention and layer_idx == 0:
                print(f"-> fp32_attention is {fp32_attention}")
            if layer_idx == 0 and feature_map_kwargs is not None:
                for k, v in feature_map_kwargs.items():
                    print(f"-> {k}: {v}")
            if layer_idx == 0 and learned_kernel_kwargs is not None:
                for k, v in learned_kernel_kwargs.items():
                    print(f"-> {k}: {v}")

        self.remove_base_attn = remove_base_attn

        # Order matters: init_weights_ sets num_heads/head_dim/q_proj that
        # init_feature_map_ reads.
        self.init_weights_(base_attn, remove_base_attn)
        self.init_feature_map_(
            feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
        )

    def init_feature_map_(
        self,
        feature_map: str,
        feature_map_kwargs: dict,
        learned_kernel: Optional[str] = None,
        learned_kernel_kwargs: Optional[dict] = None,
    ):
        """
        Initialize MLP-based feature map
        """
        self.fmap_gqa = False  # Turn True if specified below
        if learned_kernel is not None and learned_kernel_kwargs is not None:
            # Ensure dict (shallow copy so the caller's kwargs are not mutated)
            learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
            learned_kernel_kwargs["num_heads"] = self.num_heads
            learned_kernel_kwargs["head_dim"] = self.head_dim
            learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
            learned_kernel_kwargs["device"] = self.q_proj.weight.device
            # Create MLP
            mlp_learned_kernel = init_learned_kernel(
                learned_kernel, **learned_kernel_kwargs
            )
        # Add "activation"; see src.models.feature_map.py
        # NOTE(review): mlp_learned_kernel is only bound inside the branch
        # above — confirm a learned kernel is always configured.
        self.feature_map_q = init_feature_map(
            name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
        )
        if self.tie_qk_kernels:  # tie mlp weights for query and key feature maps
            self.feature_map_k = self.feature_map_q
        else:
            self.feature_map_k = copy.deepcopy(self.feature_map_q)

    def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
        """
        Initialize module layers, weights, positional dependencies, etc.
        from original softmax attention layer (base_attn)
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.config.hidden_size
        self.num_heads = base_attn.config.num_attention_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.config.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups

        # (heads, head_dim) view shapes; k/v use the (possibly smaller) GQA
        # head count.
        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]

        # Copy original model projection layers (shared, not cloned)
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        try:  # If wanting to use FA2 for ground-truth inference
            self._flash_attn_uses_top_left_mask = (
                base_attn._flash_attn_uses_top_left_mask
            )
        except AttributeError:
            pass

        if self.remove_base_attn or remove_base_attn:
            del base_attn  # We don't need to keep these around
        else:
            self.base_attn = base_attn  # For some training runs helpful to just call

    def process_qkv(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,
    ):
        """
        Compute queries, keys, and values

        Returns (q, k, v, kv_seq_len); q/k/v are (batch, heads, seq_len,
        head_dim) with rotary embeddings applied and k/v repeated to the full
        head count.
        """
        b, l, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        kv_seq_len = k.shape[-2]

        # Shape is (batch_size, seq_len, num_heads, head_dim)
        q = q.view(b, l, *self.q_shape).transpose(1, 2)
        k = k.view(b, l, *self.k_shape).transpose(1, 2)
        v = v.view(b, l, *self.v_shape).transpose(1, 2)

        if (
            past_key_value is not None
        ):  # and k.shape[2] > q.shape[2]: # e.g., when generating
            past_key_value.window_size = getattr(
                self, "decode_window_size", None
            )  # self.decode_window_size
            if isinstance(
                past_key_value, Cache
            ):  # In Transformers v4.36+ this is a DynamicCache object
                kv_seq_len += past_key_value.get_usable_length(
                    kv_seq_len, self.layer_idx
                )
            else:
                kv_seq_len += past_key_value[0].shape[-2]

        # Apply rotary embeddings
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Expand GQA key/value heads to match the query head count.
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        return q, k, v, kv_seq_len

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,  # "legacy" cache approach
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
        - Consistent with HuggingFace Transformers for easy use with their pretrained models

        Three modes:
        1. base_inference: ground-truth softmax attention only (no grad)
        2. train_attention: distillation — returns both predicted and true
           attention weights/outputs in attn_weights
        3. otherwise: linear attention finetuning / generation
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_embeddings, past_key_value
        )

        if self.base_inference:
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                y_true, _, _ = softmax_attention(q, k, v, causal=True)
                y_true = (
                    y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)
                attn_weights = (None, None)

        elif self.train_attention:  # Distilling / learning attentions
            # Note for now we assume no padding when distilling; attention masks only enforce causality
            assert (
                output_attentions is True
            ), f"When training feature maps, output_attentions should be True but is {output_attentions}"
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention (just weights)
            q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
            y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
            attn_weights = (  # type: ignore
                (attn_pred, attn_true),
                (y_pred, _y_true),
            )  # Save both attention weights so we can supervise.

        else:  # Finetuning
            q, k = self.feature_map_q(q), self.feature_map_k(k)
            # Apply prefill mask: zero out keys at padded positions so they
            # don't contribute to the running linear-attention sums.
            if attention_mask is not None and q.shape[2] > 1:
                if len(attention_mask.shape) == 4:
                    lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
                        ..., None
                    ]  # b, 1, k_len, 1
                else:
                    lin_attn_mask = attention_mask.bool()[:, None, :, None]  # b, 1, k_len, 1
                k = k.masked_fill(~lin_attn_mask, 0)

            if past_key_value is not None:  # Initialize states
                if len(past_key_value.kv_states) == self.layer_idx:
                    b, h, _, f = k.shape
                    past_key_value.kv_states.append(
                        torch.zeros(
                            b, h, f, self.head_dim, dtype=q.dtype, device=q.device
                        )
                    )
                    past_key_value.k_states.append(
                        torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
                    )
                # Generating: single-token decode against the cached states
                if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
                    assert use_cache is True
                    kv_state, k_state = past_key_value.update(
                        k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
                    )
                    if self.fp32_attention:
                        q = q.float()
                        y_true = (
                            torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
                            / torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
                                ..., None
                            ]
                        ).to(dtype=k.dtype)
                    else:
                        y_true = (
                            torch.einsum("bhlf,bhfd->bhld", q, kv_state)
                            / torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
                        )
                else:
                    # Prefill: compute the full linear attention, then fold
                    # the (detached) k/v into the cache for later decoding.
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                    y_true, _, _ = linear_attention(
                        q, k, v, self.fp32_attention, self.eps
                    )  # Ordinarily the states are ignored
                    past_key_value.update(
                        k.detach(),
                        v.detach(),
                        self.layer_idx,
                        accumulate_in_fp32=self.fp32_attention,
                    )
                    # doing some unnecessary recomputation here
            else:
                y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)

            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
            attn_weights = None

        return y_true, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionState(Cache):
    """
    Handle the KV and K states for linear attention
    - Adopts HF Transformers `past_key_values` convention
    - Inherits from `Cache` class
    - Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self) -> None:
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        # Per-layer token counters and running linear-attention states.
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []  # per layer: (b, h, f, d)
        self.k_states: List[torch.Tensor] = []   # per layer: (b, h, 1, f)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed.
        """
        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        if len(self._seen_tokens_by_layer) <= layer_idx:  # Initializing kv and k states
            self._seen_tokens_by_layer.append(0)
        return self._seen_tokens_by_layer[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """
        Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
        """
        return None

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        # length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fold new (feature-mapped) keys/values into this layer's running
        kv_state (sum of k v^T) and k_state (sum of k), and return the
        updated states. Accumulation is done under no_grad.
        """
        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.no_grad():
            # Global token counter is advanced once per model step (layer 0).
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]
            dtype = key_states.dtype
            if accumulate_in_fp32:
                key_states, value_states = key_states.float(), value_states.float()

            # kv_state = sum_l k_l v_l^T ; k_state = sum_l k_l
            kv_state = torch.einsum(
                "bhlf,bhld->bhfd", key_states, value_states
            ).detach()
            k_state = key_states.sum(
                dim=-2, keepdim=True
            ).detach()  # b, h, 1, f; note the 1
            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                print(
                    "if len(self.k_states) <= layer_idx: # Initializing kv and k states"
                )
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))
            else:
                # Accumulate onto the existing states (in fp32 if requested),
                # storing back in the original dtype.
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def to_legacy_cache(self):
        """Hack, but just return self"""
        return self

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        -> Copied from transformers/src/transformers/cache_utils.py
        """
        raise NotImplementedError(
            "Reordering cache not implemented for LinearAttentionState"
        )
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# feature map functions
|
||||||
|
# -------------------
|
||||||
|
|
||||||
|
|
||||||
|
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
    """
    Build the full feature map for linear attention: the learned-kernel `mlp`
    wrapped with the final activation named by `name` (see FeatureMap below).
    """
    return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
    """
    Initialize feature map final activation for linear attention

    `fullspace` selects between the full and half-space variants of the
    softmax/exp activations; it is ignored for pos_elu and relu.
    """
    if name == "softmax_dim":
        return SoftmaxDim(**kwargs) if fullspace else SoftmaxDimHalfspace(**kwargs)
    if name == "exp_dim":
        return Exp(**kwargs) if fullspace else ExpHalfspace(**kwargs)
    if name == "pos_elu":
        return PosELU(**kwargs)
    if name == "relu":
        return ReLU(**kwargs)
    raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def init_learned_kernel(name: str, **kwargs):
    """
    Initialize feature map MLP for linear attention.

    Args:
        name: 'untied_head_einsum' for a per-head linear map
            (`FeatureMapMLP`) or 'untied_head_adapter' for a bottleneck
            adapter (`FeatureMapAdapter`).
        **kwargs: forwarded to the selected module's constructor.

    Raises:
        NotImplementedError: if `name` is not a recognized kernel type.
    """
    if name == "untied_head_einsum":
        return FeatureMapMLP(**kwargs)
    elif name == "untied_head_adapter":
        return FeatureMapAdapter(**kwargs)
    else:
        # Fix: name the unsupported option instead of a bare NotImplementedError
        raise NotImplementedError(f"Unknown learned kernel: {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMap(nn.Module):
    """
    Final 'activation' of feature map. Can probably be combined with
    `FeatureMapMLP` below

    Full feature map is like f(xW + b)
    -> This is the `f` part
    """

    def __init__(
        self,
        activation_name: str,
        head_dim_idx: int = -1,
        eps: float = 1e-12,
        mlp: Optional[nn.Module] = None,
        fullspace: bool = True,
    ):
        """
        Args:
            activation_name: passed to `init_feature_map_act` to select `f`.
            head_dim_idx: index of the head dimension (unused in this class's
                own methods; stored for callers).
            eps: numerical floor forwarded to the activation.
            mlp: learnable `xW + b` part; identity when None.
            fullspace: whether the activation uses the full-space variant.
        """
        super().__init__()
        self.head_dim_idx = head_dim_idx
        self.eps = eps
        # Identity fallback lets `forward` call self.mlp unconditionally.
        self.mlp = mlp if mlp is not None else nn.Identity()
        self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)

    def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
        """
        Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
        """
        # The raw `x` is passed as a second argument to the activation;
        # the FeatureMapAct subclasses in this file ignore it via *args.
        return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)

    def q_map(self, *args, **kwargs):
        """
        Use for inference in case q and k feature maps differ
        """
        return self.forward(*args, **kwargs)

    def k_map(self, *args, **kwargs):
        """
        Use for inference in case q and k feature maps differ
        """
        return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Feature map activations
|
||||||
|
# -----------------------
|
||||||
|
class FeatureMapAct(nn.Module):
    """
    Base class for feature map activations.

    Behaves as the identity; subclasses override `forward` and use
    `self.eps` as a numerical floor when clamping.
    """

    def __init__(self, eps: float = 1e-12):
        super().__init__()
        # Floor applied by subclasses via .clamp(min=self.eps).
        self.eps = eps

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        x.shape is (batch_size, n_heads, seq_len, head_dim); returned as-is.
        """
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class PosELU(FeatureMapAct):
    """
    1 + ELU activation as in https://arxiv.org/abs/2006.16236
    (strictly positive feature map, floored at self.eps).
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return torch.clamp(F.elu(x) + 1, min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class ReLU(FeatureMapAct):
    """
    ReLU activation as in https://arxiv.org/abs/2103.13076
    (floored at self.eps to keep features strictly positive).
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return x.relu().clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxDim(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    Full-space variant: concatenates softmax over +x and -x along the
    feature dim (output has 2x the input feature dim).
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        pos = x.softmax(dim=-1)
        neg = (-x).softmax(dim=-1)
        return torch.cat([pos, neg], dim=-1).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxDimHalfspace(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    Halfspace variant: softmax over +x only (feature dim unchanged).
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return x.softmax(dim=-1).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Exp(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    Full-space variant: shifted exponentials of +x and -x concatenated
    along the feature dim (shifts by per-row max/min keep exp() <= 1).
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        x_max = x.amax(dim=-1, keepdim=True)
        x_min = x.amin(dim=-1, keepdim=True)
        pos = torch.exp(x - x_max)
        neg = torch.exp(x_min - x)
        return torch.cat([pos, neg], dim=-1).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpHalfspace(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    Halfspace variant: exp(x - max) only (feature dim unchanged).
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        shifted = x - x.amax(dim=-1, keepdim=True)
        return shifted.exp().clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------
|
||||||
|
# Feature map MLPs
|
||||||
|
# ----------------
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapMLP(nn.Module):
    """
    Learnable MLP in feature map.

    Full feature map is like f(xW + b)
    -> This is the `W` and (optional) `b` part

    Parameters are per-head: `layer` is (num_heads, head_dim, feature_dim)
    and (when enabled) `bias` broadcasts over batch and sequence.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,  # input dim
        feature_dim: int,  # output dim
        dtype: torch.dtype,
        device: torch.device,
        skip_connection: bool = False,
        bias: bool = False,
        zero_init: bool = False,
        normal_init: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.feature_dim = feature_dim
        self.dtype = dtype
        self.device = device
        self.skip_connection = skip_connection
        # NOTE: `self.bias` starts as a bool flag and is *reassigned* by
        # init_weights_() to either an nn.Parameter or the float 0.0.
        self.bias = bias
        self.zero_init = zero_init
        self.normal_init = normal_init
        self.init_weights_()

        if self.zero_init:  # Zero-out weights or set as identity post-initialization
            self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()

        if self.normal_init:
            # Overrides the kaiming init from init_weights_() above.
            with torch.no_grad():
                nn.init.normal_(self.layer)

        if self.skip_connection:
            assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
            assert self.head_dim == self.feature_dim, assertion_fail

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        self.layer = nn.Parameter(
            torch.zeros(
                (self.num_heads, self.head_dim, self.feature_dim),
                dtype=self.dtype,
                device=self.device,
            )
        )
        nn.init.kaiming_uniform_(self.layer)

        if self.bias:
            # Reassigns the bool flag to a learnable per-head bias.
            self.bias = nn.Parameter(
                torch.zeros(
                    (1, self.num_heads, 1, 1),  # self.feature_dim),
                    dtype=self.dtype,
                    device=self.device,
                )
            )
            nn.init.kaiming_uniform_(self.bias)
        else:
            self.bias = 0.0  # hack — lets forward() add it unconditionally

    def zero_init_with_skip_(self):
        """
        Initialize weights to zero matrix if skip connection
        (so the map starts as the identity via the skip).
        """
        with torch.no_grad():
            nn.init.zeros_(self.layer)

    def zero_init_(self):
        """
        Initialize weights to identity matrix if no skip connection
        """
        with torch.no_grad():
            for i in range(self.layer.shape[0]):
                try:
                    nn.init.eye_(self.layer[i])
                except RuntimeError:
                    # eye_ can fail on a non-leaf slice; fall back to
                    # writing an explicit identity of the right dtype.
                    with torch.no_grad():
                        dtype = self.layer[i].dtype
                        weight = torch.eye(
                            *self.layer[i].shape,
                            requires_grad=self.layer[i].requires_grad,
                            device=self.layer[i].device,
                        )
                        self.layer[i] = weight.to(dtype=dtype)

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        """
        # Per-head matmul: (h, d, f) x (b, h, l, d) -> (b, h, l, f)
        _x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
        return x + _x if self.skip_connection else _x
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapAdapter(FeatureMapMLP):
    """
    Learnable Feature map with bottleneck adapter
    as in https://arxiv.org/abs/1902.00751

    We don't use but could be fun to try
    """

    def __init__(self, hidden_dim: int, *args, **kwargs):
        # Adapters always use a skip connection, biases, and zero init
        # (so they start out as the identity map).
        kwargs["skip_connection"] = True
        kwargs["bias"] = True
        kwargs["zero_init"] = True
        # Must be set before super().__init__, which calls init_weights_().
        self.hidden_dim = hidden_dim
        super().__init__(*args, **kwargs)

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        kwargs = {"dtype": self.dtype, "device": self.device}
        # Down-projection (head_dim -> hidden_dim) ...
        self.layer0 = nn.Parameter(
            torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
        )
        # ... and up-projection (hidden_dim -> feature_dim), both per-head.
        self.layer1 = nn.Parameter(
            torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.layer0)
        nn.init.kaiming_uniform_(self.layer1)

        self.bias0 = nn.Parameter(
            torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
        )
        self.bias1 = nn.Parameter(
            torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.bias0)
        nn.init.kaiming_uniform_(self.bias1)

    def zero_init_with_skip_(self):
        # Zero every parameter so the adapter is the identity via the skip.
        with torch.no_grad():
            nn.init.zeros_(self.layer0)
            nn.init.zeros_(self.layer1)
            nn.init.zeros_(self.bias0)
            nn.init.zeros_(self.bias1)

    def zero_init_(self):
        # Identity init without a skip connection is undefined for the
        # bottleneck shape.
        raise NotImplementedError

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        -> Down-project, apply nonlinearity, up-project; add skip connection
        """
        _x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
        _x = F.relu(_x)
        _x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
        return x + _x if self.skip_connection else _x
|
||||||
@@ -0,0 +1,460 @@
|
|||||||
|
"""
|
||||||
|
Subquadratic attention combining sliding window and linear attentions
|
||||||
|
- Using "standard" sliding windows
|
||||||
|
- Didactically computes outputs with n^2 attention weights for now
|
||||||
|
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
from .linear_attention import (
|
||||||
|
LinearAttentionState,
|
||||||
|
LolcatsLinearAttention,
|
||||||
|
softmax_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Sliding window helpers
|
||||||
|
# ----------------------
|
||||||
|
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore

    Args:
        window_size: number of most-recent key positions covered by the
            sliding-window (softmax) term.
        q_len: number of query positions.
        k_len: number of key positions (may differ from q_len when a
            cache contributes extra keys).
        device: device the masks are allocated on.

    Returns:
        (window_mask, linear_mask), each of shape (1, 1, q_len, k_len),
        broadcastable over (b, h, q_len, k_len).
    """
    # Fix: clamp the diagonal offset at 0 so the mask stays causal when
    # q_len > k_len — the unclamped `k_len - q_len` went negative and
    # dropped valid entries (matches the later revision of this helper).
    kv_offset = max(k_len - q_len, 0)
    causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        kv_offset
    )
    linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        kv_offset - window_size
    )
    # Window term = causal entries not already covered by the linear term.
    window_mask = causal_mask - linear_mask
    # Return softmax mask (window), linear attention mask
    # -> shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions.

    q, k, v are the raw projections; f_q, f_k are the feature-mapped
    queries/keys used by the linear term. `window_factor` /
    `linear_factor` weight the two terms before the shared normalization.
    Optional kv_state / k_state fold in linear-attention statistics from
    earlier chunks.

    Returns (y, a): outputs in q.dtype and the combined attention weights
    for this chunk.

    NOTE(review): `eps` is unused in this body — kept for API
    compatibility with callers; confirm before removing.
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention), computed in fp32 and scaled
    # by 1/sqrt(head_dim).
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention)
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine (shared denominator normalizes both terms jointly)
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        # Prior-chunk normalizer contribution (k_state has seq dim 1).
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    Softmax attention covers the most recent `window_size` positions;
    feature-mapped linear attention covers everything earlier. The two
    terms are mixed with a (optionally learnable) sigmoid window factor.
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        """
        Args:
            window_size: sliding-window span used during training/prefill.
            decode_window_size: window span used at decode time; defaults
                to `window_size`.
            affine_attention_factors: if True, linear factor = 1 - window
                factor; otherwise the linear factor is fixed at 1.
            init_window_factor: pre-sigmoid initial value of the factor.
            train_window_factor: learn the factor (nn.Parameter) vs. freeze
                it as a buffer.
            state_grad_enabled: stored on self for callers.
            **kwargs: forwarded to LolcatsLinearAttention; must include
                'attention_type'.
        """
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        # Fix: `attention_type` was assigned twice from the same kwarg;
        # keep a single assignment.
        self.attention_type = kwargs["attention_type"]
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(
            k
        )  # Have to do after repeat for grouped-query attn if we use same fmap

        if self.train_attention:
            # 1. Compute "ground-truth" attention output and weights
            # (teacher softmax attention, no grads; already projected here).
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
            # Package (weights, outputs) pairs for the distillation loss.
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms over the cached window
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms from the state
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training: fold in prior-chunk states if any
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.quadratic_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
        return y_true, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionSlidingWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)

    Keeps, per layer:
      - kv_states / k_states: linear-attention statistics over the whole
        sequence (chunked training).
      - decode_kv_states / decode_k_states: the same statistics excluding
        the last `window_size` positions (used at decode time).
      - k_cache / v_cache: raw keys/values for the sliding window itself.
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        # Per-layer token counters, appended on first update per layer.
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`

        Returns (kv_state, k_state) for `layer_idx`.

        Raises:
            ValueError: if `fmap_key_states` or `layer_idx` is None.
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states must not be None")

        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state over the full chunk (adds the in-window terms)
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # Fix: initialize the per-layer counter here — it was never
                # appended before, so the `+=` below (and the one in
                # `update_for_decoding`) raised IndexError on later chunks.
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding.

        Until the window is full, new keys/values are simply appended; once
        full, the oldest cached position is folded into the linear-attention
        decode states and the window slides forward by one.

        Returns (k_cache, v_cache, decode_kv_state, decode_k_state)
        for `layer_idx`.
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # Evict the oldest position into the linear-attention states.
                # NOTE(review): assumes the cache holds no zero-padding; the
                # upstream code carried a (disabled) heuristic for short
                # padded inputs — confirm padding never reaches this path.
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(
                    dtype
                )  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
|
||||||
@@ -0,0 +1,685 @@
|
|||||||
|
"""
|
||||||
|
Subquadratic attention combining sliding window and linear attentions
|
||||||
|
- Using "standard" sliding windows
|
||||||
|
- Didactically computes outputs with n^2 attention weights for now
|
||||||
|
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||||
|
from .linear_attention import (
|
||||||
|
LinearAttentionState,
|
||||||
|
LolcatsLinearAttention,
|
||||||
|
causal_dot_product,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Sliding window helpers
|
||||||
|
# ----------------------
|
||||||
|
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore
    """
    # Diagonal offset of the causal boundary; clamped at 0 so the mask
    # stays causal even when queries outnumber keys.
    diag = max(k_len - q_len, 0)
    ones = torch.ones((q_len, k_len), device=device, dtype=torch.int)
    causal_mask = ones.tril(diag)
    # Positions older than the sliding window -> linear attention term.
    linear_mask = ones.tril(diag - window_size)
    # The window itself is the causal region minus the linear region.
    window_mask = causal_mask - linear_mask
    # Return softmax mask (window), linear attention mask
    # -> shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions.

    Quadratic-memory reference implementation: the full (q_len, k_len) score
    matrix is materialized for both terms, which are then normalized by a
    shared denominator (sum of softmax and linear weights).

    Args:
        q, k: queries / keys, einsum layout (batch, heads, len, head_dim).
        f_q, f_k: feature-mapped queries / keys for the linear-attention term.
        v: values, same layout as q / k.
        window_factor, linear_factor: mixing weights for the softmax and
            linear terms (broadcast against the score matrix).
        window_size: width of the sliding softmax window.
        kv_state, k_state: optional prior linear-attention states from an
            earlier chunk; when both are given they add to the output and
            to the normalizer.
        eps: unused in this quadratic path (kept for signature parity with
            the linear-complexity variants).
        mask_value: large negative fill for masked-out softmax scores.

    Returns:
        (y, a): output cast back to q.dtype, and the combined attention
        weights for the current chunk.
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention), scaled by 1/sqrt(head_dim)
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention); masked-out terms contribute 0
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine with a shared normalizer
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        # Prior state also contributes to the normalizer
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Hybrid window attention linear
|
||||||
|
# ------------------------------
|
||||||
|
def under_window_linear_attention(
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_size: int,
    linear_factor: torch.Tensor,
    eps: float = 1e-12,
):
    """Compute hybrid window attention dot product with linear complexity in q_len.

    Covers only the "under window" part: each query position attends (via
    linear attention) to keys at least `window_size` positions in the past.
    Returns the unnormalized outputs and the per-position normalizer.
    """
    dtype = f_q.dtype
    w = window_size
    # Shift keys/values right by `w` (zero-padding the front, dropping the
    # tail) so the causal kernel below only combines each query with
    # keys/values >= w positions behind it.
    f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
    v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
    # causal_dot_product is the CUDA kernel from fast-transformers; inputs
    # are made contiguous fp32, output cast back to the input dtype.
    qkv = linear_factor * causal_dot_product(
        f_q.contiguous().to(dtype=torch.float32),
        f_k.contiguous().to(dtype=torch.float32),
        v.contiguous().to(dtype=torch.float32),
    ).to(dtype=dtype)
    # Running normalizer: cumulative sum of the (shifted) feature-mapped keys
    sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
    sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
    # Avoid division by zero for positions with no in-scope keys
    sum_qk[sum_qk == 0] += eps
    return qkv, sum_qk
|
||||||
|
|
||||||
|
|
||||||
|
def sliding_window_softmax_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    window_size: int,
    window_factor: torch.Tensor,
    mask_value: float = -1e8,
):
    """Sliding-window softmax attention without materializing the full
    O(seq_len^2) score matrix.

    Keys/values are unfolded into per-position windows of length
    `window_size`; returns the unnormalized weighted values and the
    per-position normalizer. Note: `mask_value` is not referenced here —
    padding positions are suppressed via the exact-zero heuristic below.
    """
    head_dim = q.shape[-1]
    # Unfold keys/values into a trailing window dimension of size window_size
    unfold_args = {"dimension": 2, "size": window_size, "step": 1}
    k_win = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**unfold_args)
    v_win = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**unfold_args)

    # Windowed scores; causal by construction of the left-padded windows
    scores = torch.einsum("bhld,bhldw->bhlw", q, k_win) * (head_dim**-0.5)
    # Heuristic: exact-zero scores come from the zero padding above; push
    # them to the most negative representable value so exp() kills them.
    scores[scores == 0] = -torch.finfo(q.dtype).max
    scores_max = torch.amax(scores, dim=-1, keepdim=True)
    weights = window_factor * torch.exp(scores - scores_max)
    normalizer = weights.sum(dim=-1, keepdim=True)
    return torch.einsum("bhlw,bhldw->bhld", weights, v_win), normalizer
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: Optional[torch.Tensor] = None,
    linear_factor: Optional[torch.Tensor] = None,
    window_size: int = 64,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """Hybrid attention combining sliding-window softmax and linear attention.

    Uses O(n) memory in the sequence length by padding + unfolding windows
    instead of materializing n x n attention weights.

    Note: `kv_state` and `k_state` are accepted for signature parity with
    `hybrid_attention_quadratic` but are not used here.

    Raises:
        ValueError: if `window_factor` or `linear_factor` is None.
    """
    if window_factor is None:
        raise ValueError("window_factor must be provided")

    if linear_factor is None:
        raise ValueError("linear_factor must be provided")

    # Softmax attention over each position's local window (kept grad-free)
    with torch.no_grad():
        win_qkv, win_norm = sliding_window_softmax_attention(
            q, k, v, window_size, window_factor, mask_value
        )

    # Linear attention covering everything before each window
    lin_qkv, lin_norm = under_window_linear_attention(
        f_q, f_k, v, window_size, linear_factor, eps
    )

    # Combine both terms under a shared normalizer; no attention weights
    # are materialized in this path, hence the None second return.
    combined = (win_qkv + lin_qkv) / (win_norm + lin_norm)
    return combined, None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    The softmax term covers a local sliding window of `window_size` tokens;
    linear attention "fills in" everything earlier. The two are mixed by a
    per-head `window_factors` logit (sigmoid-ed in forward()).
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        # Window attributes are assigned before the parent constructor runs
        self.window_size = window_size
        # Decode-time window defaults to the training window size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        # Determine how we compute attentions
        self.linear_attention = hybrid_attention_linear
        self.attention_type = "lolcats_llama_window_sw"
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            # One learnable logit per head, broadcastable over (b, h, l, l)
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            # Fixed (non-trainable) per-head factors
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns (y_true, attn_weights, past_key_value); in the
        base-inference branch the middle element is instead a
        (hidden_states, flash-attention outputs) pair kept on CPU.
        """
        b, l, _ = hidden_states.size()

        if self.train_attention and self.base_inference:
            # Attention-transfer path: run the original Flash Attention 2
            # forward (no grad) to produce teacher outputs.
            with torch.no_grad():
                _y_true = flash_attention_2(
                    self,  # self.base_attn,
                    hidden_states=hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    use_cache=False,
                )[0]
                # _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
                y_true = _y_true.reshape(b, l, -1).contiguous()
                y_true = self.o_proj(y_true)
                # Move the layer inputs/outputs to CPU for later use
                layer_io = (hidden_states.cpu(), _y_true.cpu())  # hack
            return y_true, layer_io, None

        else:
            q, k, v, kv_seq_len = self.process_qkv(
                hidden_states, attention_mask, position_ids, past_key_value
            )
            f_q, f_k = self.feature_map_q(q), self.feature_map_k(
                k
            )  # Have to do after repeat for grouped-query attn if we use same fmap

            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                # affine factors make the two terms a convex combination
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.linear_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating (single-token decode step)
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms over the cached window
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms from the KV state
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    # Shared normalizer across both terms
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training over chunked sequences
                    # Prior chunk states may not exist yet for this layer
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.linear_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            _y_true = y_true.transpose(1, 2).contiguous()
            y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))

            if self.train_attention:
                attn_weights = _y_true  # flash_attn outputs are shape (b, l, h, d)
        return y_true, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionSlidingWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        # Per-layer count of tokens processed so far
        self._seen_tokens_by_layer: List[int] = []
        # Linear-attention states accumulated over the whole sequence (training)
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        # States excluding the last `window_size` tokens (used while decoding)
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        # Exact keys / values for the most recent `window_size` tokens
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`

        Raises:
            ValueError: if `fmap_key_states` or `layer_idx` is None.
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states must not be None")

        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.set_grad_enabled(grad_enabled):
            # Only count tokens once per model pass (at the first layer)
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state over the full chunk (adds the last window_size terms)
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # FIX: initialize this layer's token counter. Previously the
                # list was never populated here, so the `+=` updates below and
                # in `update_for_decoding` raised IndexError on first use.
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                # The exact window cache always holds the newest tokens
                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding.

        Returns (k_cache, v_cache, decode_kv_state, decode_k_state) for the
        given layer.
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # Window full: fold the oldest cached position into the
                # accumulated linear-attention states, then append the new
                # token. NOTE(review): exactly one position is evicted per
                # call, so this assumes single-token decode steps.
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(
                    dtype
                )  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Flash Attention 2
|
||||||
|
# -----------------
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
):
    """
    Wrapper for LlamaFlashAttention2
    Copied and modified from HF Transformers v4.36 and v4.43 implementations
    - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
    - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456

    Returns (attn_output, past_key_value). `output_attentions` and
    `use_cache` are accepted for signature compatibility but unused here.

    Raises:
        RuntimeError: if neither an attention-module-level nor the
            Transformers `_flash_attention_forward` implementation is
            available (i.e. its import failed at module load).
    """

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    try:  # As in Transformers v4.36 (rotary_emb takes seq_len)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
    except Exception:  # As in Transformers v4.39 (rotary_emb takes position_ids)
        cos, sin = self.rotary_emb(key_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        # Lazy %-style args: the message is only formatted when DEBUG is on
        LOG.debug(
            "The input hidden states seems to be silently casted in float32, this might be related to"
            " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            " %s.",
            target_dtype,
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if getattr(self, "_flash_attention_forward", False):
        # Attention module provides its own flash-attn forward (v4.36 style)
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=True,
        )
    else:
        # FIX: fail loudly instead of calling None when the Transformers
        # flash-attention utils could not be imported at module load time.
        if _flash_attention_forward is None:
            raise RuntimeError(
                "flash attention forward is unavailable: "
                "transformers.modeling_flash_attention_utils could not be imported"
            )
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=0,  # dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=True,
        )
    return attn_output, past_key_value
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
LoLCATs attention combining sliding window and linear attentions
|
||||||
|
- Using standard sliding window arrangement
|
||||||
|
- Training over long sequences with fixed memory with recurrent view
|
||||||
|
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
from .linear_window_attention_sw import hybrid_attention_quadratic
|
||||||
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    Same as `LolcatsTKWindowLongAttention` except the quadratic attention
    path uses the standard sliding-window arrangement
    (`hybrid_attention_quadratic` from `linear_window_attention_sw`).
    """

    def __init__(self, remove_base_attn=True, **kwargs):
        # keep self.base_attn for Flash Attention inference
        # NOTE(review): the `remove_base_attn` argument is ignored — the call
        # below hardcodes True, and the comment above appears to contradict
        # that. Confirm whether callers should be able to pass False here.
        super().__init__(remove_base_attn=True, **kwargs)
        self.quadratic_attention = hybrid_attention_quadratic
|
||||||
@@ -0,0 +1,466 @@
|
|||||||
|
"""
|
||||||
|
Subquadratic attention combining sliding window and linear attentions
|
||||||
|
- Using the TK "terracing" arrangement
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
from .linear_attention import (
|
||||||
|
LinearAttentionState,
|
||||||
|
LolcatsLinearAttention,
|
||||||
|
softmax_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Sliding window helpers
|
||||||
|
# ----------------------
|
||||||
|
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return masks for the softmax (terraced window) and linear attention
    terms -> 1 is include, 0 is ignore.

    The softmax mask is a block-diagonal "terrace": each window-size block
    plus the block immediately to its left, restricted to the causal lower
    triangle. The linear mask is the remaining causal positions. Both are
    returned with shape (1, 1, q_len, k_len) for (b, h) broadcasting.
    """
    win_len = window_size
    num_blocks = math.ceil(max(q_len, k_len) / window_size)
    # n x n block-diagonal of all-ones windows, n = num_blocks * win_len
    mask = torch.block_diag(*([torch.ones((win_len, win_len))] * num_blocks))
    # Shifting the block diagonal left by one window adds the terracing
    mask = mask + torch.roll(mask, -win_len, -1)
    # Crop from the bottom-right so the causal diagonal lines up
    if mask.shape[0] > q_len:
        mask = mask[-q_len:]
    if mask.shape[1] > k_len:
        mask = mask[:, -k_len:]
    mask = mask[None, None, ...]  # b, h, q_len, k_len
    # Softmax mask is the causal part of the terrace; the linear mask is
    # the causal complement
    softmax_mask = torch.tril(mask).to(device=device, dtype=torch.int)
    linear_mask = torch.tril(1 - mask).to(device=device, dtype=torch.int)
    return softmax_mask, linear_mask
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions.

    Same computation as the sliding-window variant, but the window / linear
    split comes from this module's terraced `get_masks`. Quadratic memory:
    the full (q_len, k_len) score matrix is materialized for both terms and
    they share one normalizer.

    Args:
        q, k: queries / keys, einsum layout (batch, heads, len, head_dim).
        f_q, f_k: feature-mapped queries / keys for the linear term.
        v: values, same layout as q / k.
        window_factor, linear_factor: mixing weights for the two terms.
        window_size: width of each terrace window.
        kv_state, k_state: optional prior linear-attention states from an
            earlier chunk; when both given they add to output and normalizer.
        eps: unused in this quadratic path (signature parity only).
        mask_value: large negative fill for masked-out softmax scores.

    Returns:
        (y, a): output cast back to q.dtype, and the combined attention
        weights for the current chunk.
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention), scaled by 1/sqrt(head_dim)
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention); masked-out terms contribute 0
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine with a shared normalizer
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        # Prior state also contributes to the normalizer
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
class LolcatsTKWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    Exact softmax attention is computed inside a trailing window of
    ``window_size`` tokens; a learned linear attention approximates the
    contribution of all earlier tokens. A per-head factor ``window_factors``
    (passed through a sigmoid) balances the two terms.
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        """
        Args:
            window_size: number of trailing tokens handled by exact softmax attention.
            decode_window_size: window used at decode time (defaults to ``window_size``).
            affine_attention_factors: if True, the linear-attention factor is
                ``1 - sigmoid(window_factors)`` so the two mixing factors sum to 1.
            init_window_factor: initial (pre-sigmoid) value of the mixing factor.
            train_window_factor: if True the mixing factor is a trainable
                ``nn.Parameter``; otherwise it is registered as a (frozen) buffer.
            state_grad_enabled: whether gradients flow through recurrent states.
            **kwargs: forwarded to ``LolcatsLinearAttention``; must contain
                ``attention_type``.
        """
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        # FIX: `attention_type` was previously assigned twice with the same
        # value; a single assignment suffices.
        self.attention_type = kwargs["attention_type"]  # e.g. 'hedgehog_llama_window_tk'
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled
        self.window_factor = self.window_factors  # legacy naming support

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns ``(output, attn_weights, past_key_value)``; the content of
        ``attn_weights`` depends on the mode (see branches below).
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        # Have to do after repeat for grouped-query attn if we use same fmap
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)

        if self.train_attention:
            # Attention transfer: compare predicted attention against the
            # exact softmax "teacher".
            # 1. Compute "ground-truth" attention output and weights
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms (max-subtracted for stability;
                    # the shared normalizer accounts for the shift)
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhld,bhnd->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        # First chunk for this layer: no prior state yet.
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.quadratic_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
        return y_true, attn_weights, past_key_value
||||||
|
|
||||||
|
class LinearAttentionTKWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training.

        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`

        Returns the accumulated ``(kv_state, k_state)`` for ``layer_idx``.
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states should not be None")

        if layer_idx is None:
            raise ValueError("layer_idx should not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # Accumulate recurrent states in fp32 for numerical stability.
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state over the full chunk
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # FIX: the per-layer token counter was never initialized
                # (the append was commented out), so the `+=` on the update
                # branch below raised IndexError on the second chunk.
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding.

        Returns ``(k_cache, v_cache, decode_kv_state, decode_k_state)`` for
        ``layer_idx`` after absorbing the token evicted from the window.
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # Window is full: fold the oldest cached token into the
                # linear-attention states, then slide the window by one.
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(
                    dtype
                )  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
"""
|
||||||
|
LoLCATs + ThunderKittens linear attention + sliding window for generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .linear_attention import LinearAttentionState
|
||||||
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
||||||
|
|
||||||
|
LOG.debug("Successfully imported ThunderKittens for TK window attention")
|
||||||
|
except ImportError:
|
||||||
|
LOG.debug("Failed to import ThunderKittens for TK window attention")
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
    """Generation-only attention: the prefill pass is offloaded to the fused
    ThunderKittens `hedgehog` kernel; decoding uses the recurrent sliding
    window + linear attention formulation."""

    def __init__(self, *args, window_size: int = 64, **kwargs):
        super().__init__(*args, **kwargs)
        # Generation never runs attention transfer or base FA2 inference.
        self.train_attention = False
        self.base_inference = False
        # NOTE(review): the `window_size` argument is accepted but ignored —
        # the TK kernel only supports a 64-token window; confirm callers
        # never pass a different value expecting it to take effect.
        self.window_size = 64  # hard-coded support for TK kernel
        self.decode_window_size = 64

        # Pre-allocated output/state buffers the TK kernel writes into.
        # NOTE(review): shapes assume batch=1, 32 heads, seq_len<=8192,
        # head_dim=128 and a CUDA device — confirm for other model configs.
        b, h, l, d = 1, 32, 8192, 128
        self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
        self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
        self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,  # "legacy" cache approach
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns ``(output, None, past_key_value)``; attention weights are
        never materialized on this generation-only path.
        """
        b, l, _ = hidden_states.size()
        # Generation-only invariants: a cache must exist; transfer/base
        # inference modes are unsupported here.
        assert (
            past_key_value is not None
        ), "past_key_value must be provided for generation"
        assert (
            self.train_attention is False
        ), "train_attention is not supported for generation"
        assert (
            self.base_inference is False
        ), "base_inference is not supported for generation"
        assert use_cache is True, "use_cache must be True for generation"
        past_key_value.window_size = self.decode_window_size
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        if q.shape[2] == 1 and kv_seq_len > 1:  # Generating after prefill
            f_q = self.feature_map_q(q)
            _kv = past_key_value.update_for_decoding(
                k, v, self.layer_idx, self.feature_map_k
            )
            k_cache, v_cache, kv_state, k_state = _kv
            # Sliding window + linear attention decode
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1

            # Softmax attention terms (max-subtracted for numerical stability;
            # the shared normalizer below accounts for the shift)
            a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
                k.shape[-1] ** -0.5
            )
            a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
            a_sm = window_factors * torch.exp(a_sm - a_sm_max)
            sum_sm = a_sm.sum(dim=-1, keepdim=True)

            # Combine with linear attention terms
            y_true = torch.einsum(
                "bhmn,bhnd->bhmd", a_sm, v_cache.float()
            ) + linear_factors * torch.einsum(
                "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
            )
            sum_ln = (
                linear_factors
                * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
                    ..., None
                ]
            )
            self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

        else:  # Process prefill
            # Use TK-implemented linear + terrace window attention
            b, h, l, d = q.shape
            device = q.device
            # Per-head mixing factors for the kernel (beta = window weight).
            betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
            alphas = (
                1 - betas
                if self.affine_attention_factors
                else torch.ones(betas.shape, dtype=torch.float32, device=device)
            )
            q_map = self.feature_map_q.mlp.layer
            k_map = self.feature_map_k.mlp.layer
            # Saves outputs to y_pred, k_state, kv_state, where we fuse:
            # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
            # 2. y_pred = attention(q, k, f_q, f_k, v)  # b, h, l, d
            # 3. kv_state = torch.einsum('bhlf,bhld->bhfd',
            #        f_k[:, :, :-self.window_size],
            #        v[:, :, :-self.window_size])  # b, h, f, d
            # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2)  # b, h, d

            tk_window_hedgehog_attention(
                q.contiguous(),
                k.contiguous(),
                v.contiguous(),
                self.y_true,
                self.k_state,
                self.kv_state,
                q_map,
                k_map,
                alphas,
                betas,
            )

            # Seed the decode cache with the kernel-computed states.
            past_key_value.update_with_kv(
                self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
            )

        # Concatenate heads and apply output projection
        y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
        y_true = self.o_proj(y_true)
        return y_true, None, past_key_value
|
|
||||||
|
|
||||||
|
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.window_size = window_size

        # Linear-attention states for tokens outside the sliding window,
        # plus a window-sized KV cache per layer.
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []

    def update_with_kv(
        self,
        kv_state: torch.Tensor,
        k_state: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_idx: int,
    ):
        """
        Update the cache with new KV and K states.

        Called once per layer at prefill with the kernel-computed states.
        NOTE(review): `k_cache`/`v_cache` are appended unconditionally, so a
        second prefill for the same layer would grow the lists instead of
        replacing the entries — confirm this method is prefill-once-only.
        """
        if layer_idx == 0:
            self._seen_tokens += k.shape[2]
        self._seen_tokens_by_layer.append(k.shape[2])

        # Initialize KV and K states
        if len(self.decode_k_states) <= layer_idx:
            self.decode_kv_states.append(kv_state)
            self.decode_k_states.append(k_state)
        else:  # Update KV and K states
            self.decode_kv_states[layer_idx] = (
                self.decode_kv_states[layer_idx] + kv_state
            )
            self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state

        self.k_cache.append(k[:, :, -self.window_size :, :])
        self.v_cache.append(v[:, :, -self.window_size :, :])

    def update_for_decoding(
        self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
    ):
        """
        Update the cache for decoding.

        Folds the oldest windowed token into the linear-attention states,
        slides the per-layer window by one, and returns
        ``(k_cache, v_cache, decode_kv_state, decode_k_state)``.
        """
        k_cache = self.k_cache[layer_idx]
        v_cache = self.v_cache[layer_idx]
        # Oldest token leaves the window -> absorb it into the running states.
        k_state = feature_map_k(k_cache[:, :, :1, :])
        v_state = v_cache[:, :, :1, :]
        kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
            k.dtype
        )

        self.decode_kv_states[layer_idx] += kv_state
        self.decode_k_states[layer_idx] += k_state

        self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
        self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
        if layer_idx == 0:
            self._seen_tokens += k.shape[-2]
        self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
        return (
            self.k_cache[layer_idx],
            self.v_cache[layer_idx],
            self.decode_kv_states[layer_idx],
            self.decode_k_states[layer_idx],
        )
|
||||||
@@ -0,0 +1,306 @@
|
|||||||
|
"""
|
||||||
|
LoLCATs attention combining sliding window and linear attentions
|
||||||
|
- Using the TK "terracing" arrangement
|
||||||
|
- Training over long sequences with fixed memory with recurrent view
|
||||||
|
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from .linear_attention import softmax_attention
|
||||||
|
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||||
|
|
||||||
|
LOG = logging.getLogger(
|
||||||
|
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    Long-sequence variant: during attention transfer the softmax "teacher"
    output comes from Flash Attention 2 (`flash_attention_2` below) instead
    of the quadratic reference, and chunked training carries recurrent
    KV/K states across chunks via `past_key_value`.
    """

    def __init__(self, remove_base_attn=True, **kwargs):
        # keep self.base_attn for Flash Attention inference
        # NOTE(review): the `remove_base_attn` parameter is ignored — `True`
        # is always forwarded to the parent regardless of the value passed;
        # confirm no caller relies on passing False here.
        super().__init__(remove_base_attn=True, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        if self.train_attention and self.base_inference:
            # Attention-transfer data collection: run the frozen base model's
            # Flash Attention 2 path and return its (pre-projection) output
            # as the teacher signal.
            with torch.no_grad():
                _y_true = flash_attention_2(
                    self,  # self.base_attn,
                    hidden_states=hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    use_cache=False,
                )[0]
                # _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
                y_true = _y_true.reshape(b, l, -1).contiguous()
                y_true = self.o_proj(y_true)
                layer_io = (hidden_states, _y_true)  # hack
            return y_true, layer_io, None

        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)

        # attention_mask = None  # For now this is always True
        if past_key_value is None:  # Regular training
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
        else:
            past_key_value.window_size = self.decode_window_size
            if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training:  # Generating
                assert use_cache is True
                _kv = past_key_value.update_for_decoding(
                    k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                )
                k_cache, v_cache, f_kv_state, f_k_state = _kv

                # Sliding window + linear attention decode
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )

                a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
                    k.shape[-1] ** -0.5
                )
                # a_sm = torch.softmax(a_sm, dim=-1)
                # Max-subtracted exponentials; the shared normalizer below
                # accounts for the shift.
                a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                sum_sm = a_sm.sum(dim=-1, keepdim=True)

                y_pred = torch.einsum(
                    "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                ) + linear_factors * torch.einsum(
                    "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                )
                sum_ln = (
                    linear_factors
                    * torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
                        ..., None
                    ]
                )
                y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)

            else:  # Stateful training
                if (
                    self.state_grad_enabled
                    and self.layer_idx == 0
                    and position_ids is not None
                ):
                    LOG.debug(
                        f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
                    )
                    LOG.debug(
                        f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
                    )
                try:
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                except IndexError:
                    # First chunk for this layer: no prior state yet.
                    kv_state, k_state = None, None
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_pred, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                    kv_state=kv_state,
                    k_state=k_state,
                )
                # Save and update KV cache and states
                # past_key_value.update(k, v.detach(), self.layer_idx,
                #                       fmap_key_states=f_k.detach(),
                #                       accumulate_in_fp32=True)
                past_key_value.update(
                    k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
                )

        # Concatenate heads and apply output projection
        _y_pred = y_pred.transpose(1, 2).contiguous()
        y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))

        if self.train_attention:
            with torch.no_grad():
                a_true = softmax_attention(q, k, None, causal=True)[1]
            attn_weights = (_y_pred, (a_pred, a_true))
        else:
            attn_weights = _y_pred  # flash_attn outputs are shape (b, l, h, d)
        return y_pred, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Flash Attention 2
|
||||||
|
# -----------------
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
):
    """
    Wrapper for LlamaFlashAttention2

    Copied and modified from HF Transformers v4.36 and v4.43 implementations
    - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
    - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456

    NOTE(review): `output_attentions` and `use_cache` are accepted for
    signature compatibility with the HF attention interface but are not read
    in this body. Returns (attn_output, past_key_value).
    """
    bsz, q_len, _ = hidden_states.size()

    # Q/K/V projections on the raw hidden states.
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    # Rotary embeddings: older Transformers take (x, seq_len=...), newer take
    # (x, position_ids). The broad except is a deliberate version-compat
    # fallback, not error suppression.
    try:  # As in Transformers v4.36
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
    except Exception:  # As in Transformers v4.39
        cos, sin = self.rotary_emb(key_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        LOG.debug(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Prefer a bound `_flash_attention_forward` method (older Transformers
    # attached it to the attention module); otherwise call the module-level
    # helper from newer Transformers.
    if getattr(self, "_flash_attention_forward", False):
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=True,
        )
    else:
        # NOTE(review): dropout is hard-disabled on this path (`dropout=0`
        # instead of `dropout_rate`) — presumably intentional for
        # distillation stability; confirm before relying on dropout here.
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=0,  # dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=True,
        )
    return attn_output, past_key_value
|
||||||
@@ -0,0 +1,361 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
"""Linear LLaMA model implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
LlamaModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .configuration_linear_llama import LinearLlamaConfig
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose softmax attention is swapped for linear attention.

    A stock ``LlamaDecoderLayer`` is constructed first; its ``self_attn`` is
    then replaced by the attention class selected via
    ``config.attention_config``.
    """

    def __init__(self, config: LinearLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Swap in the custom (linear) attention for this layer.
        linear_attn = convert_llama_attention(
            layer=self, attention_config=config.attention_config
        )
        self.self_attn = linear_attn
|
||||||
|
|
||||||
|
|
||||||
|
class LinearLlamaModel(LlamaModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]

    Args:
        config: LinearLlamaConfig
    """

    config_class = LinearLlamaConfig
    base_model_prefix = "linear_llama"

    def __init__(self, config: LinearLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embeddings, honoring the configured padding index.
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        # One linear-attention decoder layer per hidden layer.
        self.layers = nn.ModuleList(
            [
                LinearLlamaDecoderLayer(config, idx)
                for idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
|
||||||
|
|
||||||
|
|
||||||
|
class LinearLlamaForCausalLM(LlamaForCausalLM):
    """Causal language model built on :class:`LinearLlamaModel`."""

    config_class = LinearLlamaConfig
    base_model_prefix = "linear_llama"

    def __init__(self, config):
        super().__init__(config)
        self.model = LinearLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @classmethod
    def from_llama(
        cls,
        model: LlamaForCausalLM,
        config: LinearLlamaConfig,
        train_attention: bool = False,
        remove_base_attn: bool = True,
    ) -> "LinearLlamaForCausalLM":
        """Build a LinearLlamaForCausalLM from an existing LlamaForCausalLM.

        The donor model's decoder (with attentions converted per
        ``config.attention_config``) and lm_head are reused in place of the
        freshly initialized ones.
        """
        if config is None:
            raise ValueError("Missing config")

        # Start from a newly constructed model of this class.
        converted = cls(config=config)

        # Discard the fresh submodules; the donor model's replace them.
        del converted.model
        del converted.lm_head

        converted.model = convert_attention(
            model.model,
            attention_config=config.attention_config,
            train_attention=train_attention,
            remove_base_attn=remove_base_attn,
        )
        converted.lm_head = model.lm_head
        converted.vocab_size = model.vocab_size

        return converted

    def toggle_attention(self, train: bool = True):
        """Toggle whether the attention layers are trainable."""
        # Delegates to the module-level helper of the same name.
        toggle_attention(self.model, train=train)

    def remove_base_attention(self):
        """Drop the retained base (softmax) attentions after distillation."""
        # Delegates to the module-level helper of the same name.
        remove_base_attention(self.model)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_attention(
    model: nn.Module,
    attention_config: dict,
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """Convert every eligible attention layer in ``model``.

    Layers listed in ``attention_config["softmax_attentions"]`` keep their
    softmax attention and are frozen; all others are replaced by the linear
    attention selected by ``attention_config["attention_type"]``.
    """
    keep_softmax = attention_config.get("softmax_attentions", [])
    attention_type = attention_config.get("attention_type")

    if attention_type == "softmax":
        # Nothing to convert for the plain softmax configuration.
        LOG.info(
            f"-> attention_config.attention_type is {attention_type}; not converting attentions"
        )
        return model

    layers = traverse_layers(model)
    for idx, layer in enumerate(tqdm(layers, desc="Converting attentions...")):
        if idx in keep_softmax:
            # Freeze any preserved softmax attention layers.
            for param in layer.parameters():
                param.requires_grad = False
            continue
        layer.self_attn = convert_llama_attention(
            layer,
            attention_config,
            layers,
            train_attention,
            remove_base_attn,
        )
        layer.self_attn.converted = True
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def toggle_attention(llama_model: nn.Module, train: bool = False):
    """Set the ``train_attention`` flag on every attention layer.

    Pass ``train=True`` during attention transfer; set it back to ``False``
    when finetuning.
    """
    for decoder_layer in traverse_layers(llama_model):
        decoder_layer.self_attn.train_attention = train
    return llama_model
|
||||||
|
|
||||||
|
|
||||||
|
def remove_base_attention(llama_model: nn.Module):
    """Delete the retained teacher ("base") attention modules, if any.

    Used after distillation when the teacher attentions were kept around.
    """
    for decoder_layer in traverse_layers(llama_model):
        attn = decoder_layer.self_attn
        if getattr(attn, "base_attn", False):
            del attn.base_attn
    return llama_model
|
||||||
|
|
||||||
|
|
||||||
|
def traverse_layers(model: nn.Module, verbose: bool = False):
    """Locate and return the model's decoder layer list.

    Probes the usual attribute paths in order: a causal-LM wrapper
    (``model.model.layers``), a bare base model (``model.layers``), and a
    PEFT-wrapped model (``model.base_model.model.model.layers``).
    """
    try:
        found = model.model.layers
        if verbose:
            LOG.info("-> Loading from model.model.layers")
    except AttributeError as wrapper_err:  # if base model
        if verbose:
            LOG.info(wrapper_err)
        try:
            found = model.layers
            if verbose:
                LOG.info("-> Loading from model.layers")
        except AttributeError as base_err:  # If we make a PEFT model
            if verbose:
                LOG.info(base_err)
            found = model.base_model.model.model.layers
            if verbose:
                LOG.info("-> Loading from model.base_model.model.model.layers")
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def convert_llama_attention(
    layer: nn.Module,
    attention_config: dict,
    layers: Optional[list[nn.Module]] = None,  # list of layers
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """Convert a single layer's attention as specified by ``attention_config``."""
    attn_factory = get_attention(**attention_config)
    max_idx = len(layers) - 1 if layers else None
    return attn_factory(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=max_idx,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_attention(attention_type: str, **kwargs):
    """Get the linear attention class for ``attention_type``.

    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'

    Imports are kept local so optional variants (e.g. the ThunderKittens
    build) are only required when actually selected. Returns a
    ``functools.partial`` pre-bound with ``kwargs``, or ``None`` if the type
    is not recognized.
    """
    kwargs["attention_type"] = attention_type

    attn_cls = None
    if attention_type == "lolcats_llama":
        from .linear_attention import LolcatsLinearAttention

        attn_cls = LolcatsLinearAttention
    elif attention_type == "lolcats_llama_window_tk":
        from .linear_window_attention_tk import LolcatsTKWindowAttention

        attn_cls = LolcatsTKWindowAttention
    elif attention_type == "lolcats_llama_window_sw":
        from .linear_window_attention_sw import LolcatsSlidingWindowAttention

        attn_cls = LolcatsSlidingWindowAttention
    elif attention_type == "lolcats_llama_window_sw_linear":
        from .linear_window_attention_sw_linear import (
            LolcatsLinearSlidingWindowAttention,
        )

        attn_cls = LolcatsLinearSlidingWindowAttention
    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
        from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention

        attn_cls = LolcatsTKWindowLongAttention
    elif attention_type == "lolcats_long_llama_window_sw":
        from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention

        attn_cls = LolcatsSlidingWindowLongAttention
    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen

        attn_cls = LolcatsWindowAttentionTKGen

    if attn_cls is None:
        LOG.info(f"-> attention_type {attention_type} not handled... returning None")
        return None
    return partial(attn_cls, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating.

    Returns ``past_key_values`` unchanged when no attention type is set or
    for plain softmax attention; otherwise returns a fresh cache object
    matching the linear-attention variant.

    Note: matching is substring-based, so more specific names must be tested
    before their prefixes. The original implementation had two unreachable
    branches (``"llama_window_sw_linear"`` was shadowed by ``"llama_window_sw"``
    and a duplicate ``"lolcats_llama_window_tk_gen"`` equality check was
    shadowed by the earlier ``in`` test); both dead branches resolved to the
    same cache classes and have been removed.
    """
    if attention_type is None:
        return past_key_values

    # TK generation build (requires Thunderkittens); must be checked before
    # the generic "llama_window_tk" substring below.
    if "lolcats_llama_window_tk_gen" in attention_type:
        from .linear_window_attention_tk_gen import (
            LinearAttentionTKWindowGenerationCache,
        )

        return LinearAttentionTKWindowGenerationCache()

    if "llama_window_tk" in attention_type:
        from .linear_window_attention_tk import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    # Both sliding-window variants ("..._sw" and "..._sw_linear") use the
    # same cache class.
    if "llama_window_sw" in attention_type:
        from .linear_window_attention_sw import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    if "softmax" in attention_type:
        return past_key_values

    from .linear_attention import LinearAttentionState

    return LinearAttentionState()
|
||||||
|
|
||||||
|
|
||||||
|
def register_linear_llama():
    """Register the Linear LLaMA classes with the Transformers library."""
    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

    # Make the "linear_llama" model_type resolvable via the Auto* APIs.
    AutoConfig.register("linear_llama", LinearLlamaConfig)
    AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
    AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)

    # Register for auto classes so the modeling files are saved alongside
    # checkpoints.
    for klass, auto_name in (
        (LinearLlamaConfig, "AutoConfig"),
        (LinearLlamaModel, "AutoModel"),
        (LinearLlamaForCausalLM, "AutoModelForCausalLM"),
    ):
        klass.register_for_auto_class(auto_name)
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
|
||||||
|
|
||||||
|
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from torch import Tensor, nn, tensor
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class DistillAttentionXentMSETrainer(AxolotlTrainer):
    """
    Custom trainer class for distilling attentions ("attention transfer").

    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
    """

    def __init__(
        self,
        model: nn.Module,
        mse_factor: float = 1e3,
        xent_factor: float = 0,
        **kwargs: Any,
    ):
        """
        Args:
            model: Model whose converted attention layers are being distilled.
            mse_factor: Weight on the MSE(attention outputs) loss term.
            xent_factor: Weight on the CrossEntropy(attention weights) loss term.
        """
        super().__init__(model=model, **kwargs)
        self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
        self.criterion_mse = nn.MSELoss(reduction="mean")
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor

        self.model_accepts_loss_kwargs = False  # added to combat explosive loss

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> tuple[Tensor, dict]:
        """
        Attention distillation ("attention transfer").

        For each layer and head, get attentions and train to minimize some
        combo of MSE and cross-entropy loss.

        Returns the scalar loss, optionally paired with a metrics dict when
        ``return_outputs`` is True.
        """
        # alias inputs to data
        data = inputs

        device = model.device

        # Filter out labels; the distillation target is the teacher attention.
        inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}

        # set num_items_in_batch
        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}

        # Forward pass; converted layers emit (pred_output, (a_pred, a_true))
        # in place of standard attention weights.
        outputs = model(**inputs, output_attentions=True, use_cache=False)
        outputs = outputs.get("attentions")

        # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
        # n_layers x (predicted_attns, true_attns)
        # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
        loss_mse = tensor(0.0, device=device)
        loss_xent = tensor(0.0, device=device)
        n_layers = 0  # Number of layers to distill
        for attns in outputs:
            if attns is None:
                # Preserved softmax layer: nothing to distill.
                continue
            if len(attns) != 2:
                # Unexpected payload (plain attention weights); move off-device
                # and skip, matching the original best-effort behavior.
                attns = attns.cpu()
                continue
            if self.xent_factor > 0:
                # Cross-entropy loss between student and teacher weights
                a_pred, a_true = attns[0]
                a_pred = a_pred.clamp(
                    min=1e-12
                ).log()  # nn.CrossEntropy assumes unnormalized logits
                k_len = a_true.shape[-1]  # batch, n_heads, q_len, k_len
                # Compute mean cross-entropy over all queries
                a_pred = a_pred.contiguous().view(-1, k_len)
                a_true = a_true.contiguous().view(-1, k_len)
                loss_xent += self.criterion_xent(a_pred, a_true)
            if self.mse_factor > 0:
                loss_mse += self.criterion_mse(*attns[1])
            n_layers += 1

        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        # BUGFIX: `loss` was previously assigned only inside the
        # `n_layers > 0` branch, raising NameError at the return below when no
        # layer produced distillation attentions. Always define it.
        loss = loss_xent + loss_mse

        # Metrics for logging. BUGFIX: report plain floats for both loss
        # terms (previously `loss_mse` leaked a tensor while `loss_xent`
        # used `.item()`).
        metrics = {
            "loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
            "loss_mse": loss_mse.item() if self.mse_factor > 0 else 0,
            "mse_factor": self.mse_factor,
            "xent_factor": self.xent_factor,
        }
        if "position_ids" in data:
            metrics["input_len"] = data["position_ids"].shape[1]
            metrics["position_ids"] = data["position_ids"][0].detach().cpu().numpy()
        outputs = metrics
        return (loss, outputs) if return_outputs else loss
|
||||||
Reference in New Issue
Block a user