Compare commits
35 Commits
feat/linea
...
grpo-path
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9d842ef2e | ||
|
|
ecea44c902 | ||
|
|
4f9c57e95d | ||
|
|
3d38bc82b8 | ||
|
|
756a8332d6 | ||
|
|
aded9c500d | ||
|
|
3659d812f7 | ||
|
|
bdb0f97082 | ||
|
|
65b6519447 | ||
|
|
a1958b09de | ||
|
|
b8f258817e | ||
|
|
753146b458 | ||
|
|
d683c50113 | ||
|
|
234cd8311e | ||
|
|
f9893e3842 | ||
|
|
ac1ebc58a8 | ||
|
|
56f3b9f20f | ||
|
|
2c1376d8c4 | ||
|
|
3c7517fd55 | ||
|
|
1e94d7ef65 | ||
|
|
cfc7fe0df2 | ||
|
|
3c4fe478cf | ||
|
|
c810599c66 | ||
|
|
300ffc2cb6 | ||
|
|
b1c4711145 | ||
|
|
d155849e2c | ||
|
|
626db6cb84 | ||
|
|
79159b4871 | ||
|
|
704ddd6ff1 | ||
|
|
54b0d3d0e8 | ||
|
|
59ad21f2de | ||
|
|
57264b6491 | ||
|
|
d495e41ba1 | ||
|
|
6067fe6c28 | ||
|
|
a620d481e2 |
@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.48.1
|
transformers==4.48.2
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.13.0
|
trl==0.14.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
@@ -26,7 +26,7 @@ sentencepiece
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
|
|
||||||
modal==0.70.5
|
modal==0.70.5
|
||||||
pydantic==2.6.3
|
pydantic==2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -153,5 +153,8 @@ setup(
|
|||||||
"ray": [
|
"ray": [
|
||||||
"ray[train]",
|
"ray[train]",
|
||||||
],
|
],
|
||||||
|
"vllm": [
|
||||||
|
"vllm>=0.7.1",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,8 +22,11 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
|
|||||||
|
|
||||||
# modal workaround so it doesn't use the automounted axolotl
|
# modal workaround so it doesn't use the automounted axolotl
|
||||||
new_env = copy.deepcopy(os.environ)
|
new_env = copy.deepcopy(os.environ)
|
||||||
if "PYTHONPATH" in new_env:
|
# if "PYTHONPATH" in new_env:
|
||||||
del new_env["PYTHONPATH"]
|
# python_path = Path(new_env["PYTHONPATH"].split(":")[0])
|
||||||
|
# if python_path.joinpath("src", "axolotl").exists():
|
||||||
|
# # we don't want to use the automounted axolotl or unexpected behavior happens
|
||||||
|
# del new_env["PYTHONPATH"]
|
||||||
|
|
||||||
# Propagate errors from subprocess.
|
# Propagate errors from subprocess.
|
||||||
if exit_code := subprocess.call( # nosec B603
|
if exit_code := subprocess.call( # nosec B603
|
||||||
|
|||||||
@@ -1,135 +0,0 @@
|
|||||||
"""CLI to run training on a model."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import fire
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
|
||||||
from axolotl.cli.config import load_cfg
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.integrations.base import PluginManager
|
|
||||||
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
|
|
||||||
LinearLlamaConfig,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
|
||||||
LinearLlamaForCausalLM,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.models import load_model_config
|
|
||||||
from axolotl.utils.trainer import setup_trainer
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|
||||||
"""
|
|
||||||
Convert attention to linear attention and perform attention transfer via distillation.
|
|
||||||
"""
|
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
|
||||||
check_user_token()
|
|
||||||
|
|
||||||
# ensure quantization and peft are turned off (due to how we need to re-apply peft later)
|
|
||||||
cfg.load_in_8bit = False
|
|
||||||
cfg.load_in_4bit = False
|
|
||||||
cfg.adapter = None
|
|
||||||
|
|
||||||
# load model
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
|
||||||
|
|
||||||
# freeze model
|
|
||||||
for p in model.parameters():
|
|
||||||
p.requires_grad = False
|
|
||||||
|
|
||||||
# convert to linear llama
|
|
||||||
linear_llama_config = LinearLlamaConfig.from_llama(
|
|
||||||
model.config, cfg.attention_config
|
|
||||||
)
|
|
||||||
model = LinearLlamaForCausalLM.from_llama(
|
|
||||||
model, config=linear_llama_config, train_attention=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# set save_path, save tokenizer and model config.
|
|
||||||
save_path = str(os.path.join(cfg.output_dir, "distilled"))
|
|
||||||
tokenizer.save_pretrained(save_path)
|
|
||||||
if hasattr(model, "config"):
|
|
||||||
model.config.save_pretrained(save_path)
|
|
||||||
|
|
||||||
# Get datasets
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
train_dataset = dataset_meta.train_dataset
|
|
||||||
eval_dataset = dataset_meta.eval_dataset
|
|
||||||
total_num_steps = dataset_meta.total_num_steps
|
|
||||||
|
|
||||||
# toggle attention to be trainable
|
|
||||||
model.toggle_attention(train=True)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = setup_trainer(
|
|
||||||
cfg=cfg,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
model=(model, None, None),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
processor=None,
|
|
||||||
total_num_steps=total_num_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
# train
|
|
||||||
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
|
||||||
|
|
||||||
# drop base_attention + remove training attn
|
|
||||||
model.toggle_attention(train=False)
|
|
||||||
model.remove_base_attention()
|
|
||||||
|
|
||||||
# NOTE: If in peft mode, consider whether to auto-merge
|
|
||||||
|
|
||||||
# save model
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
# NOTE: may need to consider other ways of saving due to multi-gpu etc
|
|
||||||
model.save_pretrained(save_path, safe_serialization=safe_serialization)
|
|
||||||
|
|
||||||
# cleanup
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|
||||||
"""
|
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Path to `axolotl` config YAML file.
|
|
||||||
kwargs: Additional keyword arguments to override config file values.
|
|
||||||
"""
|
|
||||||
# load cfg, force linearize and add plugin to linearize
|
|
||||||
parsed_cfg = load_cfg(
|
|
||||||
config,
|
|
||||||
linearize=True,
|
|
||||||
plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser = HfArgumentParser(TrainerCliArgs)
|
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
|
||||||
return_remaining_strings=True
|
|
||||||
)
|
|
||||||
|
|
||||||
do_linearize(parsed_cfg, parsed_cli_args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
|
||||||
@@ -12,6 +12,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
import yaml
|
import yaml
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
@@ -208,7 +209,7 @@ def train(
|
|||||||
accelerate_args.append(str(main_process_port))
|
accelerate_args.append(str(main_process_port))
|
||||||
if "num_processes" in kwargs:
|
if "num_processes" in kwargs:
|
||||||
num_processes = kwargs.pop("num_processes", None)
|
num_processes = kwargs.pop("num_processes", None)
|
||||||
accelerate_args.append("--num-processes")
|
accelerate_args.append("--num_processes")
|
||||||
accelerate_args.append(str(num_processes))
|
accelerate_args.append(str(num_processes))
|
||||||
|
|
||||||
base_cmd = ["accelerate", "launch"]
|
base_cmd = ["accelerate", "launch"]
|
||||||
@@ -381,4 +382,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlDPOTrainer,
|
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlDPOConfig,
|
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
AxolotlORPOConfig,
|
AxolotlORPOConfig,
|
||||||
AxolotlPRMConfig,
|
AxolotlPRMConfig,
|
||||||
@@ -652,7 +653,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
@@ -965,10 +966,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
if self.cfg.dataset_processes:
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
if self.cfg.rl_beta:
|
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||||
if self.cfg.orpo_alpha:
|
if self.cfg.orpo_alpha:
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
@@ -977,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
training_args_cls = AxolotlCPOConfig
|
training_args_cls = AxolotlCPOConfig
|
||||||
training_args_kwargs["loss_type"] = "simpo"
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
@@ -1001,11 +1004,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kto_undesirable_weight or 1.0
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
|
elif self.cfg.rl == "grpo":
|
||||||
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
|
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rl == "ipo":
|
if self.cfg.rl == "ipo":
|
||||||
@@ -1016,9 +1023,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||||
|
training_args_kwargs[
|
||||||
|
"use_logits_to_keep"
|
||||||
|
] = self.cfg.dpo_use_logits_to_keep
|
||||||
|
|
||||||
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
|
if blocklist_key in training_args_kwargs:
|
||||||
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
output_dir=self.cfg.output_dir,
|
self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
@@ -1047,8 +1062,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
|
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||||
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
@@ -1068,7 +1087,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
if self.cfg.datasets is not None and (
|
||||||
|
trainer_cls is DPOStrategy.get_trainer_class()
|
||||||
|
):
|
||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,30 +5,21 @@ module for customized trainers
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import (
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
CPOTrainer,
|
|
||||||
DPOTrainer,
|
|
||||||
KTOTrainer,
|
|
||||||
ORPOTrainer,
|
|
||||||
PRMTrainer,
|
|
||||||
RewardTrainer,
|
|
||||||
)
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
|
||||||
|
|
||||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dataset_tags = dataset_tags
|
|
||||||
self.optimizer = None
|
|
||||||
self.model_accepts_loss_kwargs = False
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if self.args.loraplus_lr_ratio is None:
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args,
|
|
||||||
opt_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
if loraplus_lr_ratio:
|
|
||||||
print("Using lora+")
|
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
@wraps(DPOTrainer.push_to_hub)
|
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
|
||||||
"""
|
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
|
||||||
)
|
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tokenize_row(
|
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
) -> Dict:
|
|
||||||
res = DPOTrainer.tokenize_row(
|
|
||||||
features,
|
|
||||||
processing_class,
|
|
||||||
max_prompt_length,
|
|
||||||
max_completion_length,
|
|
||||||
add_special_tokens,
|
|
||||||
)
|
|
||||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
|
||||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
|
||||||
res[key] = res[key][1:]
|
|
||||||
|
|
||||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
|
||||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
|
||||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
|
||||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
|
||||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
|
||||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
|
||||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
|
||||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def training_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
|
||||||
num_items_in_batch=None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
|||||||
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""
|
||||||
|
DPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class DPOStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for DPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trainer_class(cls):
|
||||||
|
return AxolotlDPOTrainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_training_args_class(cls):
|
||||||
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
|
|
||||||
|
return AxolotlDPOConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_training_args_kwargs(cls, cfg):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
if cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
|
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||||
|
if cfg.dpo_use_weighting is not None:
|
||||||
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
|
return training_args_kwargs
|
||||||
15
src/axolotl/core/trainers/dpo/args.py
Normal file
15
src/axolotl/core/trainers/dpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl specific DPO args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import DPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
|
"""
|
||||||
|
DPO config for DPO training
|
||||||
|
"""
|
||||||
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
DPO trainer for axolotl
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
from trl import DPOTrainer
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import (
|
||||||
|
SchedulerMixin,
|
||||||
|
_sanitize_kwargs_for_ds_tagging,
|
||||||
|
_sanitize_kwargs_for_tagging,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
|
self.optimizer = None
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
if self.args.loraplus_lr_ratio is None:
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args,
|
||||||
|
opt_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
if loraplus_lr_ratio:
|
||||||
|
print("Using lora+")
|
||||||
|
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
@wraps(DPOTrainer.push_to_hub)
|
||||||
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tokenize_row(
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
) -> Dict:
|
||||||
|
res = DPOTrainer.tokenize_row(
|
||||||
|
features,
|
||||||
|
processing_class,
|
||||||
|
max_prompt_length,
|
||||||
|
max_completion_length,
|
||||||
|
add_special_tokens,
|
||||||
|
)
|
||||||
|
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||||
|
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||||
|
for key in res.keys():
|
||||||
|
res[key] = res[key][1:]
|
||||||
|
|
||||||
|
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||||
|
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||||
|
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||||
|
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||||
|
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||||
|
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||||
|
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||||
|
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||||
|
num_items_in_batch=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return loss
|
||||||
113
src/axolotl/core/trainers/grpo/__init__.py
Normal file
113
src/axolotl/core/trainers/grpo/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
GRPO Specific Strategy for training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOStrategy:
|
||||||
|
"""
|
||||||
|
Strategy for GRPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trainer_class(cls):
|
||||||
|
return AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_training_args_class(cls):
|
||||||
|
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||||
|
|
||||||
|
return AxolotlGRPOConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_training_args_kwargs(cls, cfg):
|
||||||
|
grpo_args_kwargs = {}
|
||||||
|
if cfg.trl and cfg.trl.use_vllm:
|
||||||
|
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||||
|
if cfg.trl and cfg.trl.vllm_device:
|
||||||
|
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||||
|
else:
|
||||||
|
grpo_args_kwargs["vllm_device"] = "auto"
|
||||||
|
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"vllm_gpu_memory_utilization"
|
||||||
|
] = cfg.trl.vllm_gpu_memory_utilization
|
||||||
|
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||||
|
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||||
|
if cfg.trl and cfg.trl.num_generations:
|
||||||
|
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||||
|
if cfg.trl and cfg.trl.sync_ref_model:
|
||||||
|
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||||
|
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||||
|
grpo_args_kwargs[
|
||||||
|
"ref_model_mixup_alpha"
|
||||||
|
] = cfg.trl.ref_model_mixup_alpha
|
||||||
|
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||||
|
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||||
|
return grpo_args_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_trainer_kwargs(cls, cfg):
|
||||||
|
trainer_kwargs = {}
|
||||||
|
if cfg.trl and cfg.trl.reward_funcs:
|
||||||
|
reward_funcs = []
|
||||||
|
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||||
|
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||||
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
|
trainer_kwargs[
|
||||||
|
"reward_processing_classes"
|
||||||
|
] = cfg.trl.reward_processing_classes
|
||||||
|
return trainer_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_blocklist_args_kwargs(cls):
|
||||||
|
return ["dataset_num_proc"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||||
|
"""
|
||||||
|
Returns the reward function from the given fully qualified name, or the path to the reward function model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||||
|
or a HF hub path to the reward model.
|
||||||
|
Raises:
|
||||||
|
ValueError: If the reward function does not accept at least two arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||||
|
or a path to a reward model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# use importlib to dynamically load the reward function from the module
|
||||||
|
reward_func_module_name = reward_func_fqn.split(".")[-1]
|
||||||
|
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
|
||||||
|
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||||
|
if not len(inspect.signature(reward_func).parameters) >= 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||||
|
)
|
||||||
|
return reward_func
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# the user has passed a string (ideally indicating the path of a reward model)
|
||||||
|
LOG.info(
|
||||||
|
f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||||
|
)
|
||||||
|
return reward_func
|
||||||
15
src/axolotl/core/trainers/grpo/args.py
Normal file
15
src/axolotl/core/trainers/grpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Axolotl Specific Training Args
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from trl import GRPOConfig
|
||||||
|
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||||
|
"""
|
||||||
|
Axolotl GRPO Config for GRPO training
|
||||||
|
"""
|
||||||
14
src/axolotl/core/trainers/grpo/trainer.py
Normal file
14
src/axolotl/core/trainers/grpo/trainer.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Axolotl GRPO trainer
|
||||||
|
"""
|
||||||
|
from trl import GRPOTrainer
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base GRPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|
||||||
"""
|
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
|
|
||||||
|
|
||||||
https://github.com/HazyResearch/lolcats/
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
|
||||||
|
|
||||||
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
|
||||||
# nano setup.py
|
|
||||||
|
|
||||||
# Build the CUDA kernel
|
|
||||||
python setup.py install
|
|
||||||
```
|
|
||||||
|
|
||||||
Step 1:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.lolcats.LinearizePlugin
|
|
||||||
|
|
||||||
linearize: true
|
|
||||||
```
|
|
||||||
|
|
||||||
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
|
|
||||||
|
|
||||||
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
|
||||||
|
|
||||||
# with optional config below but this requires patching axolotl
|
|
||||||
# to allow this config to work with lora
|
|
||||||
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
|
|
||||||
```
|
|
||||||
|
|
||||||
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
|
|
||||||
|
|
||||||
Step 3: Run inference on the finetuned model
|
|
||||||
|
|
||||||
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
"""
|
|
||||||
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
|
|
||||||
|
|
||||||
Low-rank Linear Conversion via Attention Transfer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
|
||||||
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
|
|
||||||
DistillAttentionXentMSETrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.lolcats")
|
|
||||||
|
|
||||||
|
|
||||||
class LinearizePlugin(BasePlugin):
|
|
||||||
"""
|
|
||||||
Plugin for lolcats integration with Axolotl.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Register the Linear Llama model with transformers
|
|
||||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
|
||||||
register_linear_llama,
|
|
||||||
)
|
|
||||||
|
|
||||||
register_linear_llama()
|
|
||||||
|
|
||||||
def get_input_args(self):
|
|
||||||
return "axolotl.integrations.lolcats.LinearAttentionArgs"
|
|
||||||
|
|
||||||
def get_trainer_cls(self, cfg):
|
|
||||||
# defualt to XentMSE
|
|
||||||
# TODO: add check to allow MSE_linear
|
|
||||||
if cfg.linearize:
|
|
||||||
return DistillAttentionXentMSETrainer
|
|
||||||
|
|
||||||
return None
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
"""
|
|
||||||
Module for handling linear attention input arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMapKwargs(BaseModel):
|
|
||||||
"""Args for feature map"""
|
|
||||||
|
|
||||||
eps: float
|
|
||||||
mlp: Optional[None] = None
|
|
||||||
fullspace: bool
|
|
||||||
|
|
||||||
|
|
||||||
class LearnedKernelKwargs(BaseModel):
|
|
||||||
"""Args for learned kernel"""
|
|
||||||
|
|
||||||
feature_dim: int
|
|
||||||
skip_connection: bool
|
|
||||||
bias: bool
|
|
||||||
zero_init: bool
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionConfig(BaseModel):
|
|
||||||
"""Args for attention config"""
|
|
||||||
|
|
||||||
attention_type: str
|
|
||||||
feature_map: str
|
|
||||||
feature_map_kwargs: FeatureMapKwargs
|
|
||||||
layer_idx: Optional[None] = None
|
|
||||||
learned_kernel: str
|
|
||||||
learned_kernel_kwargs: LearnedKernelKwargs
|
|
||||||
tie_qk_kernels: bool
|
|
||||||
train_qk: bool
|
|
||||||
|
|
||||||
|
|
||||||
class LinearAttentionArgs(BaseModel):
|
|
||||||
"""
|
|
||||||
Input args for linear attention
|
|
||||||
"""
|
|
||||||
|
|
||||||
attention_config: AttentionConfig
|
|
||||||
|
|
||||||
linearize: Optional[bool] = False
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Linear LLaMA model configuration"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import LlamaConfig
|
|
||||||
|
|
||||||
|
|
||||||
class LinearLlamaConfig(LlamaConfig):
|
|
||||||
"""
|
|
||||||
This is the configuration class to store the configuration of a [`LinearLlamaModel`].
|
|
||||||
It is a modified LlamaConfig that includes additional parameters for linear attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attention_config (`dict`):
|
|
||||||
Dictionary containing the configuration for linear attention mechanism.
|
|
||||||
Expected contents:
|
|
||||||
`attention_type` (str):
|
|
||||||
The type of attention to convert to.
|
|
||||||
`feature_map` (`str`):
|
|
||||||
The type of feature map to use for linear attention.
|
|
||||||
`feature_map_kwargs` (`dict`):
|
|
||||||
Additional arguments for the feature map.
|
|
||||||
`learned_kernel` (`str`, *optional*):
|
|
||||||
Type of learned kernel to use, if any.
|
|
||||||
`learned_kernel_kwargs` (`dict`, *optional*):
|
|
||||||
Additional arguments for the learned kernel.
|
|
||||||
`tie_qk_kernels` (`bool`, *optional*, defaults to False):
|
|
||||||
Whether to tie query and key kernels.
|
|
||||||
`rotary_config` (`dict`, *optional*):
|
|
||||||
Configuration for rotary embeddings.
|
|
||||||
`train_attention` (`bool`, *optional*, defaults to False):
|
|
||||||
Whether to train attention to match softmax attention.
|
|
||||||
`remove_base_attn` (`bool`, *optional*, defaults to True):
|
|
||||||
Whether to remove base attention after initialization.
|
|
||||||
`mask_value` (`int`, *optional*, defaults to 0):
|
|
||||||
Value to use for masking.
|
|
||||||
`eps` (`float`, *optional*, defaults to 1e-12):
|
|
||||||
Epsilon value for numerical stability.
|
|
||||||
`fp32_attention` (`bool`, *optional*, defaults to False):
|
|
||||||
Whether to use fp32 precision for attention computation.
|
|
||||||
`track_state_grads` (`bool`, *optional*, defaults to False):
|
|
||||||
Whether to track gradients of attention states.
|
|
||||||
|
|
||||||
**kwargs:
|
|
||||||
Additional arguments inherited from LlamaConfig.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_type = "linear_llama"
|
|
||||||
|
|
||||||
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
# Set auto_map
|
|
||||||
self.auto_map = {
|
|
||||||
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
|
|
||||||
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
|
|
||||||
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Set default attention config if none provided
|
|
||||||
self.attention_config = attention_config or {"attention_type": "softmax"}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
|
|
||||||
"""
|
|
||||||
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llama_config (:class:`~transformers.LlamaConfig`):
|
|
||||||
The LlamaConfig to inherit from.
|
|
||||||
|
|
||||||
attention_config (`dict`):
|
|
||||||
Dictionary containing the configuration for linear attention mechanism.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return cls(attention_config=attention_config, **llama_config.to_dict())
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
# Causal linear attention CUDA kernel
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
```bash
|
|
||||||
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
|
||||||
|
|
||||||
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
|
||||||
# nano setup.py
|
|
||||||
|
|
||||||
# Build the CUDA kernel
|
|
||||||
python setup.py install
|
|
||||||
```
|
|
||||||
|
|
||||||
Reference: https://github.com/idiap/fast-transformers/
|
|
||||||
|
|
||||||
```bib
|
|
||||||
@inproceedings{katharopoulos_et_al_2020,
|
|
||||||
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
|
|
||||||
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
|
|
||||||
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
|
|
||||||
year = {2020}
|
|
||||||
}
|
|
||||||
|
|
||||||
@article{vyas_et_al_2020,
|
|
||||||
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
|
|
||||||
title={Fast Transformers with Clustered Attention},
|
|
||||||
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
|
|
||||||
year={2020}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
|
||||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
|
||||||
# Apoorv Vyas <avyas@idiap.ch>
|
|
||||||
#
|
|
||||||
from .causal_attention import causal_dot_product
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
//
|
|
||||||
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
|
||||||
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
|
||||||
// Apoorv Vyas <avyas@idiap.ch>
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Compute a*b^T and save it into out.
|
|
||||||
*
|
|
||||||
* a \in R^A
|
|
||||||
* b \in R^B
|
|
||||||
*/
|
|
||||||
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
|
|
||||||
for (int i=0; i<A; i++) {
|
|
||||||
float * bi = b;
|
|
||||||
for (int j=0; j<B; j++) {
|
|
||||||
*out += (*a) * (*bi);
|
|
||||||
out++;
|
|
||||||
bi++;
|
|
||||||
}
|
|
||||||
a++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Implement a vector matrix product v*m and save it into out.
|
|
||||||
*
|
|
||||||
* v \in R^A
|
|
||||||
* m \in R^{AxB}
|
|
||||||
*/
|
|
||||||
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
|
|
||||||
// TODO: Consider removing the zeroing part and assuming out already
|
|
||||||
// contains 0s
|
|
||||||
for (int i=0; i<B; i++) {
|
|
||||||
out[i] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i=0; i<A; i++) {
|
|
||||||
float *oi = out;
|
|
||||||
for (int j=0; j<B; j++) {
|
|
||||||
*oi += (*v) * (*m);
|
|
||||||
oi++;
|
|
||||||
m++;
|
|
||||||
}
|
|
||||||
v++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Implement a vector transposed-matrix product and save it into out.
|
|
||||||
*
|
|
||||||
* v \in R^B
|
|
||||||
* m \in R^{AxB}
|
|
||||||
*/
|
|
||||||
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
|
|
||||||
for (int i=0; i<A; i++) {
|
|
||||||
float *vi = v;
|
|
||||||
float s = 0;
|
|
||||||
for (int j=0; j<B; j++) {
|
|
||||||
s += (*vi) * (*m);
|
|
||||||
vi++;
|
|
||||||
m++;
|
|
||||||
}
|
|
||||||
// TODO: Should we be aggregating? See the comment on vm_dot.
|
|
||||||
*out = s;
|
|
||||||
out++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Compute the causally masked dot products of queries, keys and values.
 *
 * Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
 * computation is done efficiently by changing the order of the dot products:
 * a running (E x M) sum of k_l v_l^T outer products is maintained, so the
 * whole pass is O(L * E * M) per head instead of O(L^2).
 *
 * Tensors are assumed float32, shape (N, H, L, E) for queries/keys and
 * (N, H, L, M) for values; `product` receives the (N, H, L, M) result.
 */
void causal_dot_product(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    torch::Tensor product
) {
    // Extract some shapes
    int N = queries.size(0);  // batch size
    int H = queries.size(1);  // number of heads
    int L = queries.size(2);  // sequence length
    int E = queries.size(3);  // query/key feature dimension
    int M = values.size(3);   // value feature dimension

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto pa = product.accessor<float, 4>();

    // Each (n, h) pair is independent, so parallelize over both loops.
    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            // Running "KV state": sum over i<=l of k_i v_i^T (E x M).
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();
            for (int l=0; l<L; l++) {
                // kv += k_l v_l^T (vvt_dot accumulates in place)
                vvt_dot(
                    &ka[n][h][l][0],
                    &va[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                // product[n][h][l] = q_l * kv  (causality comes from only
                // having accumulated keys/values up to position l)
                vm_dot(
                    &qa[n][h][l][0],
                    kvp,
                    &pa[n][h][l][0],
                    E,
                    M
                );
            }
        }
    }
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
 * Compute the gradients of queries, keys and values given the gradient of the
 * causal_dot_product output.
 *
 * Make sure that everything is computed in O(N D^2) complexity: two linear
 * passes over the sequence (forward for grad_queries, backward for
 * grad_keys/grad_values), each maintaining a running (E x M) state.
 */
void causal_dot_backward(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    const torch::Tensor grad_out,
    torch::Tensor grad_queries,
    torch::Tensor grad_keys,
    torch::Tensor grad_values
) {
    // Extract some shapes
    int N = queries.size(0);  // batch size
    int H = queries.size(1);  // number of heads
    int L = queries.size(2);  // sequence length
    int E = queries.size(3);  // query/key feature dimension
    int M = values.size(3);   // value feature dimension

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto ga = grad_out.accessor<float, 4>();
    auto gqa = grad_queries.accessor<float, 4>();
    auto gka = grad_keys.accessor<float, 4>();
    auto gva = grad_values.accessor<float, 4>();

    // Each (n, h) pair is independent, so parallelize over both loops.
    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            // Running (E x M) state, reused for both passes below.
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();

            // Compute the gradient wrt the queries:
            // grad_q_l = grad_out_l * (sum_{i<=l} k_i v_i^T)^T
            for (int l=0; l<L; l++) {
                // kv += k_l v_l^T
                vvt_dot(
                    &ka[n][h][l][0],
                    &va[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                vmt_dot(
                    &ga[n][h][l][0],
                    kvp,
                    &gqa[n][h][l][0],
                    E,
                    M
                );
            }

            // Compute the gradient wrt the keys and values.
            // Walk the sequence backwards accumulating
            // S = sum_{i>=l} q_i grad_out_i^T; then
            // grad_k_l = S v_l and grad_v_l = k_l S (per vmt_dot/vm_dot).
            kv.zero_();
            for (int l=L-1; l>=0; l--) {
                // kv += q_l grad_out_l^T
                vvt_dot(
                    &qa[n][h][l][0],
                    &ga[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                vmt_dot(
                    &va[n][h][l][0],
                    kvp,
                    &gka[n][h][l][0],
                    E,
                    M
                );
                vm_dot(
                    &ka[n][h][l][0],
                    kvp,
                    &gva[n][h][l][0],
                    E,
                    M
                );
            }
        }
    }
}
|
|
||||||
|
|
||||||
|
|
||||||
// Expose the CPU kernels to Python as a torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("causal_dot_product", &causal_dot_product,
          "Compute the weighted sum of values but attending only to previous "
          "values.");
    m.def("causal_dot_backward", &causal_dot_backward,
          "Compute the gradient of queries, keys and values given the gradient "
          "of causal_dot_product.");
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
|
||||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
|
||||||
# Apoorv Vyas <avyas@idiap.ch>
|
|
||||||
#
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
try:
|
|
||||||
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
|
|
||||||
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
|
|
||||||
except ImportError as e:
|
|
||||||
print(e)
|
|
||||||
causal_dot_product_cuda = causal_dot_backward_cuda = None
|
|
||||||
|
|
||||||
|
|
||||||
class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values.

    Dispatches to the compiled CUDA kernel by input device type. No CPU
    kernel is registered, so calling this with CPU tensors raises KeyError.
    The dict values may be None if the extension failed to import (see the
    module-level try/except).
    """

    # Device type -> kernel dispatch tables.
    dot = {
        "cuda": causal_dot_product_cuda,
    }
    dot_backward = {
        "cuda": causal_dot_backward_cuda,
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        """Run the causal product.

        Q, K: (N, H, L, E); V: (N, H, L, M). Returns (N, H, L, M).
        """
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)

        # Actually perform the dot product. (The deprecated `.data` access
        # used previously is unnecessary: autograd.Function.forward already
        # runs without graph tracking.)
        CausalDotProduct.dot[device.type](Q, K, V, product)

        return product

    @staticmethod
    def backward(ctx, grad_out):
        """Compute grads of Q, K, V given the gradient of the output."""
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients; the kernel accumulates into
        # these in place.
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q, K, V, grad_out, grad_Q, grad_K, grad_V
        )

        return grad_Q, grad_K, grad_V


# Alias the autograd function to python style snake case naming
causal_dot_product = CausalDotProduct.apply
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,65 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
|
||||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
|
||||||
# Apoorv Vyas <avyas@idiap.ch>
|
|
||||||
#
|
|
||||||
|
|
||||||
import subprocess # nosec
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from setuptools import setup
|
|
||||||
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_arch_torch():
    """Return the newest CUDA compute arch string (e.g. "sm_90") that the
    installed torch build was compiled for."""
    supported_archs = torch.cuda.get_arch_list()
    newest = supported_archs[-1]
    print(f"Found arch: {newest} from existing torch installation")
    return newest
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_bare_metal_version(cuda_dir):
    """Query `<cuda_dir>/bin/nvcc -V` and parse the CUDA release.

    Returns (raw_nvcc_output, major, minor), both version parts as strings.
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True  # nosec
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    # The token looks like "11.2," — strip the trailing comma before
    # splitting. The previous code took only the first character of the
    # minor part (release[1][0]), which also truncated multi-digit minors
    # such as "11.10" down to "1".
    release = output[release_idx].rstrip(",").split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1]
    return raw_output, bare_metal_major, bare_metal_minor
|
|
||||||
|
|
||||||
|
|
||||||
def append_nvcc_threads(nvcc_extra_args):
    """Append `--threads 4` to the nvcc args when the local CUDA toolkit
    supports parallel compilation (CUDA >= 11.2)."""
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # Compare as a version tuple: the previous check
    # `major >= 11 and minor >= 2` wrongly rejected e.g. CUDA 12.0 / 12.1.
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args
|
|
||||||
|
|
||||||
|
|
||||||
# Detect the newest arch of the local torch build. Currently informational:
# the actual --generate-code flag below is hard-coded for H100.
arch = get_last_arch_torch()
# Numeric suffix of the arch (e.g. "90" from "sm_90"); not used below yet.
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"]  # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']

# Build the causal attention CUDA extension.
setup(
    name="causal_attention_cuda_cpp",
    ext_modules=[
        CUDAExtension(
            "causal_attention_cuda",
            [
                # 'causal_attention.cpp',
                "causal_attention_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                # append_nvcc_threads adds --threads when CUDA >= 11.2
                "nvcc": append_nvcc_threads(
                    ["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
                ),
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
|
|
||||||
@@ -1,856 +0,0 @@
|
|||||||
"""
|
|
||||||
Linear attention classes
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
from typing import Any, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
|
|
||||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
|
||||||
try:
|
|
||||||
from csrc import causal_dot_product as fast_causal_dot_product
|
|
||||||
except ImportError:
|
|
||||||
fast_causal_dot_product = None
|
|
||||||
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
|
||||||
|
|
||||||
# -------------------
|
|
||||||
# Attention functions
|
|
||||||
# -------------------
|
|
||||||
|
|
||||||
|
|
||||||
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    Causal linear attention dot product
    - If available, use CUDA kernel from fast-transformers
    """
    if fast_causal_dot_product is not None:
        return fast_causal_dot_product(q, k, v)
    # Pure-PyTorch fallback: cumulative sum of k_l v_l^T outer products,
    # contracted with q at every position (memory-heavy: O(L * f * d)).
    outer = torch.einsum("bhlf,bhld->bhlfd", k, v)
    return torch.einsum("bhlf,bhlfd->bhld", q, outer.cumsum(dim=2))
|
|
||||||
|
|
||||||
|
|
||||||
def linear_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    fp32_attention: bool = False,
    eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Compute linear attention with CUDA kernel implementation from fast-transformers
    - https://github.com/idiap/fast-transformers
    - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
      v is shape (b, h, l, head_dim)
    """
    dtype = q.dtype
    # Causality is handled inside the cumulative kernel; the numerator is
    # always accumulated in float32.
    numerator = causal_dot_product(
        q.contiguous().to(dtype=torch.float32),
        k.contiguous().to(dtype=torch.float32),
        v.contiguous().to(dtype=torch.float32),
    )
    if fp32_attention:
        # Normalize in float32, then cast back once.
        denom = (
            torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
        )
        out = (numerator / denom[..., None]).to(dtype=dtype)
    else:
        # Cast the numerator back first, normalize in the working dtype.
        out = numerator.to(dtype=dtype)
        k_cumulative = k.float().cumsum(dim=2).to(dtype=dtype)
        denom = torch.einsum("bhld,bhld->bhl", q, k_cumulative) + eps
        out = out / denom[..., None]
    return out, None, None
|
|
||||||
|
|
||||||
|
|
||||||
def softmax_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: Optional[torch.Tensor] = None,
    causal: bool = True,
    fp32_attention: bool = True,
):
    """
    Standard softmax attention; only compute outputs if v is not None
    -> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)

    Returns (output_or_None, attention_weights, None).
    """
    y = None
    # Scaled dot-product scores: (b, h, m, n)
    scores = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
    if causal:
        # Mask out future positions with the most negative representable
        # value so softmax assigns them ~zero probability.
        m, n = scores.shape[-2:]
        future = torch.ones((m, n), device=scores.device, dtype=torch.bool).triu(
            n - m + 1
        )
        scores = scores.masked_fill(future, -torch.finfo(scores.dtype).max)
    if fp32_attention:
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        probs = torch.softmax(scores, dim=-1)
    if v is not None:
        y = torch.einsum("bhmn,bhnd->bhmd", probs, v)
    return y, probs, None
|
|
||||||
|
|
||||||
|
|
||||||
def quadratic_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: Optional[torch.Tensor] = None,
    causal: bool = True,
    fp32_attention: bool = False,
    eps: float = 1e-12,
):
    """
    Compute attention with feature maps by instantiating L x L matrix of attention weights
    -> Use for attention distillation
    -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)

    Returns (output_or_None, attention_weights, None).
    """
    y = None
    dtype = q.dtype
    if fp32_attention:
        q, k = q.float(), k.float()
    a = torch.einsum("bhmd,bhnd->bhmn", q, k)  # note we don't scale, tho we could
    if causal:  # Apply causal mask (zero future weights, not -inf: weights
        # are normalized by their sum below rather than softmaxed)
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
            n - m + 1
        )
        a = a.masked_fill(causal_mask, 0)
    # Normalize to compute attention
    a = a / (a.sum(dim=-1, keepdim=True) + eps)
    a = a.to(dtype=dtype) if fp32_attention else a
    # The original dropped into the debugger here via `breakpoint()`, which
    # hangs any non-interactive run; surface the problem explicitly instead.
    if torch.isnan(a).any():
        raise RuntimeError("NaN attention weights in quadratic_attention")
    if v is not None:
        y = torch.einsum("bhmn,bhnd->bhmd", a, v)
    return y, a, None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------
|
|
||||||
# Attention layer class
|
|
||||||
# ---------------------
|
|
||||||
|
|
||||||
|
|
||||||
class LolcatsLinearAttention(nn.Module):
    """
    LoLCATs attention implementation initialized from a
    `LlamaAttention` or `MistralAttention` object (base_attn)

    Most of the arguments are directly tied to argparse args
    - For now we don't support padding.

    Three forward modes (mutually exclusive):
    - base_inference: run the original softmax attention (no grad)
    - train_attention: distillation; returns both predicted (linear) and
      true (softmax) attention weights for supervision
    - otherwise: finetuning / inference with linear attention
    """

    def __init__(
        self,
        base_attn: nn.Module,  # like LlamaAttention
        feature_map: str,
        feature_map_kwargs: dict,
        layer_idx: Optional[int] = None,
        max_layer_idx: Optional[int] = None,
        learned_kernel: Optional[str] = None,
        learned_kernel_kwargs: Optional[dict] = None,
        tie_qk_kernels: Optional[bool] = False,
        rotary_config: Optional[dict] = None,
        train_attention: Optional[bool] = False,
        remove_base_attn: bool = True,
        attention_type: Optional[str] = "lolcats_llama",
        mask_value: int = 0,
        eps: float = 1e-12,
        fp32_attention: bool = False,
        track_state_grads: bool = False,
        rank: Optional[int] = 0,
        **kwargs,
    ) -> None:
        super().__init__()
        # Keep a plain-dict copy of the base model config, if present.
        self.base_config = getattr(base_attn, "config", None)
        if self.base_config is not None:
            self.base_config = self.base_config.to_dict()
        self.attention_type = attention_type
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
        self.max_layer_idx = max_layer_idx
        self.tie_qk_kernels = tie_qk_kernels
        self.train_attention = train_attention
        # When True, forward() runs the original softmax attention instead.
        self.base_inference = False
        self.fp32_attention = fp32_attention
        self.track_state_grads = track_state_grads
        # Only rank 0 logs config once (layer 0) to avoid duplicated prints.
        if rank == 0:  # multi-gpu
            if fp32_attention and layer_idx == 0:
                print(f"-> fp32_attention is {fp32_attention}")
            if layer_idx == 0 and feature_map_kwargs is not None:
                for k, v in feature_map_kwargs.items():
                    print(f"-> {k}: {v}")
            if layer_idx == 0 and learned_kernel_kwargs is not None:
                for k, v in learned_kernel_kwargs.items():
                    print(f"-> {k}: {v}")

        self.remove_base_attn = remove_base_attn

        # Copy projection layers from base_attn, then build feature maps
        # (order matters: feature maps read self.q_proj set by init_weights_).
        self.init_weights_(base_attn, remove_base_attn)
        self.init_feature_map_(
            feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
        )

    def init_feature_map_(
        self,
        feature_map: str,
        feature_map_kwargs: dict,
        learned_kernel: Optional[str] = None,
        learned_kernel_kwargs: Optional[dict] = None,
    ):
        """
        Initialize MLP-based feature map
        """
        self.fmap_gqa = False  # Turn True if specified below
        if learned_kernel is not None and learned_kernel_kwargs is not None:
            # Ensure dict
            learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
            learned_kernel_kwargs["num_heads"] = self.num_heads
            learned_kernel_kwargs["head_dim"] = self.head_dim
            learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
            learned_kernel_kwargs["device"] = self.q_proj.weight.device
            # Create MLP
            mlp_learned_kernel = init_learned_kernel(
                learned_kernel, **learned_kernel_kwargs
            )
        # Add "activation"; see src.models.feature_map.py
        self.feature_map_q = init_feature_map(
            name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
        )
        if self.tie_qk_kernels:  # tie mlp weights for query and key feature maps
            self.feature_map_k = self.feature_map_q
        else:
            self.feature_map_k = copy.deepcopy(self.feature_map_q)

    def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
        """
        Initialize module layers, weights, positional dependencies, etc.
        from original softmax attention layer (base_attn)
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.config.hidden_size
        self.num_heads = base_attn.config.num_attention_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.config.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups

        # Per-projection (heads, head_dim) view shapes (GQA: fewer kv heads).
        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]

        # Copy original model projection layers
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        try:  # If wanting to use FA2 for ground-truth inference
            self._flash_attn_uses_top_left_mask = (
                base_attn._flash_attn_uses_top_left_mask
            )
        except AttributeError:
            pass

        if self.remove_base_attn or remove_base_attn:
            del base_attn  # We don't need to keep these around
        else:
            self.base_attn = base_attn  # For some training runs helpful to just call

    def process_qkv(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,
    ):
        """
        Compute queries, keys, and values

        Returns (q, k, v, kv_seq_len) with q/k/v shaped
        (batch, heads, seq_len, head_dim) after rotary embedding and
        GQA kv-head repetition.
        """
        b, l, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        kv_seq_len = k.shape[-2]

        # Shape is (batch_size, seq_len, num_heads, head_dim)
        q = q.view(b, l, *self.q_shape).transpose(1, 2)
        k = k.view(b, l, *self.k_shape).transpose(1, 2)
        v = v.view(b, l, *self.v_shape).transpose(1, 2)

        if (
            past_key_value is not None
        ):  # and k.shape[2] > q.shape[2]:  # e.g., when generating
            past_key_value.window_size = getattr(
                self, "decode_window_size", None
            )  # self.decode_window_size
            if isinstance(
                past_key_value, Cache
            ):  # In Transformers v4.36+ this is a DynamicCache object
                kv_seq_len += past_key_value.get_usable_length(
                    kv_seq_len, self.layer_idx
                )
            else:
                kv_seq_len += past_key_value[0].shape[-2]

        # Apply rotary embeddings
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Repeat kv heads up to the number of query heads (GQA).
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        return q, k, v, kv_seq_len

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,  # "legacy" cache approach
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
        - Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns (output, attn_weights); attn_weights is a nested tuple of
        (pred, true) weights/outputs when distilling, else None.
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_embeddings, past_key_value
        )

        if self.base_inference:
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                y_true, _, _ = softmax_attention(q, k, v, causal=True)
                y_true = (
                    y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)
                attn_weights = (None, None)

        elif self.train_attention:  # Distilling / learning attentions
            # Note for now we assume no padding when distilling; attention masks only enforce causality
            assert (
                output_attentions is True
            ), f"When training feature maps, output_attentions should be True but is {output_attentions}"
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention (just weights)
            q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
            y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
            attn_weights = (  # type: ignore
                (attn_pred, attn_true),
                (y_pred, _y_true),
            )  # Save both attention weights so we can supervise.

        else:  # Finetuning
            q, k = self.feature_map_q(q), self.feature_map_k(k)
            # Apply prefill mask
            if attention_mask is not None and q.shape[2] > 1:
                if len(attention_mask.shape) == 4:
                    lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
                        ..., None
                    ]  # b, 1, k_len, 1
                else:
                    lin_attn_mask = attention_mask.bool()[:, None, :, None]  # b, 1, k_len, 1
                # Zero padded keys so they contribute nothing to the kv sums.
                k = k.masked_fill(~lin_attn_mask, 0)

            if past_key_value is not None:  # Initialize states
                if len(past_key_value.kv_states) == self.layer_idx:
                    b, h, _, f = k.shape
                    past_key_value.kv_states.append(
                        torch.zeros(
                            b, h, f, self.head_dim, dtype=q.dtype, device=q.device
                        )
                    )
                    past_key_value.k_states.append(
                        torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
                    )
                # Generating: single new token against accumulated states.
                if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
                    assert use_cache is True
                    kv_state, k_state = past_key_value.update(
                        k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
                    )
                    if self.fp32_attention:
                        q = q.float()
                        y_true = (
                            torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
                            / torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
                                ..., None
                            ]
                        ).to(dtype=k.dtype)
                    else:
                        y_true = (
                            torch.einsum("bhlf,bhfd->bhld", q, kv_state)
                            / torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
                        )
                else:
                    # Prefill: run linear attention over the chunk, then fold
                    # the (detached) keys/values into the cached states.
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                    y_true, _, _ = linear_attention(
                        q, k, v, self.fp32_attention, self.eps
                    )  # Ordinarily the states are ignored
                    past_key_value.update(
                        k.detach(),
                        v.detach(),
                        self.layer_idx,
                        accumulate_in_fp32=self.fp32_attention,
                    )
                    # doing some unnecessary recomputation here
            else:
                y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)

            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
            attn_weights = None

        return y_true, attn_weights
|
|
||||||
|
|
||||||
|
|
||||||
class LinearAttentionState(Cache):
    """
    Handle the KV and K states for linear attention
    - Adopts HF Transformers `past_key_values` convention
    - Inherits from `Cache` class
    - Modified from transformers.cache_utils.DynamicCache (v4.36)

    Per layer it stores a running kv_state (b, h, f, d) = sum_l k_l v_l^T
    and a running k_state (b, h, 1, f) = sum_l k_l, instead of the full
    key/value history.
    """

    def __init__(self) -> None:
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed.
        """
        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        if len(self._seen_tokens_by_layer) <= layer_idx:  # Initializing kv and k states
            self._seen_tokens_by_layer.append(0)
        return self._seen_tokens_by_layer[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """
        Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
        """
        return None

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        # length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fold the new keys/values into the running per-layer states and
        return the updated (kv_state, k_state) pair for `layer_idx`."""
        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.no_grad():
            # Token accounting is global, so only count once (on layer 0).
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]
            dtype = key_states.dtype
            if accumulate_in_fp32:
                key_states, value_states = key_states.float(), value_states.float()

            kv_state = torch.einsum(
                "bhlf,bhld->bhfd", key_states, value_states
            ).detach()
            k_state = key_states.sum(
                dim=-2, keepdim=True
            ).detach()  # b, h, 1, f; note the 1
            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                # (a leftover debug print() of this condition was removed)
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))
            else:
                # Accumulate in the (possibly fp32) compute dtype, store
                # back in the original dtype.
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def to_legacy_cache(self):
        """Hack, but just return self"""
        return self

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        -> Copied from transformers/src/transformers/cache_utils.py
        """
        raise NotImplementedError(
            "Reordering cache not implemented for LinearAttentionState"
        )
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------
|
|
||||||
# feature map functions
|
|
||||||
# -------------------
|
|
||||||
|
|
||||||
|
|
||||||
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
    """
    Build the full feature map (learned MLP + final activation) used for
    linear attention; `name` selects the activation inside FeatureMap.
    """
    feature_map = FeatureMap(activation_name=name, mlp=mlp, **kwargs)
    return feature_map
|
|
||||||
|
|
||||||
|
|
||||||
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
    """
    Initialize feature map final activation for linear attention.

    `fullspace` picks between the full-space and half-space variants of
    the softmax_dim / exp_dim activations.
    """
    if name == "softmax_dim":
        activation_cls = SoftmaxDim if fullspace else SoftmaxDimHalfspace
    elif name == "exp_dim":
        activation_cls = Exp if fullspace else ExpHalfspace
    elif name == "pos_elu":
        activation_cls = PosELU
    elif name == "relu":
        activation_cls = ReLU
    else:
        raise NotImplementedError
    return activation_cls(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def init_learned_kernel(name: str, **kwargs):
    """
    Initialize feature map MLP for linear attention.

    Args:
        name: kernel identifier; 'untied_head_einsum' (plain per-head MLP)
            or 'untied_head_adapter' (bottleneck adapter variant).
        **kwargs: forwarded to the kernel constructor.

    Raises:
        NotImplementedError: if `name` is not a recognized kernel type.
    """
    if name == "untied_head_einsum":
        return FeatureMapMLP(**kwargs)
    if name == "untied_head_adapter":
        return FeatureMapAdapter(**kwargs)
    # Name the offending input so config errors are easy to diagnose
    raise NotImplementedError(f"Unknown learned kernel: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMap(nn.Module):
    """
    Final 'activation' of the feature map. Can probably be combined with
    `FeatureMapMLP` below.

    The full feature map computes f(xW + b); this module supplies the `f`
    part, composed with an optional learnable `mlp` for the (xW + b) part.
    """

    def __init__(
        self,
        activation_name: str,
        head_dim_idx: int = -1,
        eps: float = 1e-12,
        mlp: Optional[nn.Module] = None,
        fullspace: bool = True,
    ):
        super().__init__()
        self.head_dim_idx = head_dim_idx
        self.eps = eps
        # Fall back to a pass-through when no learnable map is supplied
        self.mlp = nn.Identity() if mlp is None else mlp
        self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)

    def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
        """
        Assume x.shape is (batch_size, n_heads, seq_len, head_dim).
        """
        mapped = self.mlp(x, *mlp_args, **mlp_kwargs)
        return self.activation(mapped, x)

    def q_map(self, *args, **kwargs):
        """
        Query-side map; kept separate for inference in case q and k
        feature maps differ.
        """
        return self.forward(*args, **kwargs)

    def k_map(self, *args, **kwargs):
        """
        Key-side map; kept separate for inference in case q and k
        feature maps differ.
        """
        return self.forward(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
|
||||||
# Feature map activations
|
|
||||||
# -----------------------
|
|
||||||
class FeatureMapAct(nn.Module):
    """
    Base class for feature map activations; the base implementation is
    the identity.
    """

    def __init__(self, eps: float = 1e-12):
        super().__init__()
        # Numerical floor used by subclasses when clamping outputs
        self.eps = eps

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Identity pass-through; x.shape is (batch_size, n_heads, seq_len, head_dim).
        """
        return x
|
|
||||||
|
|
||||||
|
|
||||||
class PosELU(FeatureMapAct):
    """
    1 + ELU activation as in https://arxiv.org/abs/2006.16236
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Shift ELU into the positive range, then floor at eps for stability
        return torch.clamp(F.elu(x) + 1, min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class ReLU(FeatureMapAct):
    """
    ReLU activation as in https://arxiv.org/abs/2103.13076
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Non-negative part, floored at eps to avoid exact zeros
        return torch.clamp(F.relu(x), min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxDim(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Full-space variant: softmax over both x and -x, doubling the last dim
        pos = torch.softmax(x, dim=-1)
        neg = torch.softmax(-x, dim=-1)
        return torch.cat([pos, neg], dim=-1).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxDimHalfspace(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Halfspace variant: single softmax, last dim unchanged
        probs = x.softmax(dim=-1)
        return probs.clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class Exp(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Subtract per-row max / add per-row min before exp for numerical
        # stability; full-space variant doubles the last dim
        hi = x.amax(dim=-1, keepdim=True)
        lo = x.amin(dim=-1, keepdim=True)
        stacked = torch.cat([torch.exp(x - hi), torch.exp(lo - x)], dim=-1)
        return stacked.clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class ExpHalfspace(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Subtract per-row max before exp for numerical stability
        hi = x.amax(dim=-1, keepdim=True)
        return torch.exp(x - hi).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------
|
|
||||||
# Feature map MLPs
|
|
||||||
# ----------------
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMapMLP(nn.Module):
    """
    Learnable MLP in feature map.

    Full feature map is like f(xW + b)
    -> This is the `W` and (optional) `b` part

    Holds one (head_dim, feature_dim) weight per head, applied via einsum
    in forward. NOTE(review): `self.bias` is overloaded — it starts as the
    bool constructor flag and is replaced in `init_weights_` by either an
    nn.Parameter or the float 0.0 ("hack" below).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,  # input dim
        feature_dim: int,  # output dim
        dtype: torch.dtype,
        device: torch.device,
        skip_connection: bool = False,
        bias: bool = False,
        zero_init: bool = False,
        normal_init: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.feature_dim = feature_dim
        self.dtype = dtype
        self.device = device
        self.skip_connection = skip_connection
        self.bias = bias  # bool flag here; replaced by Parameter/0.0 in init_weights_
        self.zero_init = zero_init
        self.normal_init = normal_init
        self.init_weights_()

        if self.zero_init:  # Zero-out weights or set as identity post-initialization
            # Zero weights pair with a skip connection (net identity map);
            # without a skip, the weights themselves are set to identity.
            self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()

        if self.normal_init:
            # Overrides the kaiming init done in init_weights_
            with torch.no_grad():
                nn.init.normal_(self.layer)

        if self.skip_connection:
            # x + xW only type-checks when input and output dims agree
            assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
            assert self.head_dim == self.feature_dim, assertion_fail

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases

        Creates `self.layer` of shape (num_heads, head_dim, feature_dim) and
        rebinds `self.bias` from its bool flag to either a learnable
        (1, num_heads, 1, 1) Parameter or the scalar 0.0.
        """
        self.layer = nn.Parameter(
            torch.zeros(
                (self.num_heads, self.head_dim, self.feature_dim),
                dtype=self.dtype,
                device=self.device,
            )
        )
        nn.init.kaiming_uniform_(self.layer)

        if self.bias:
            # One scalar bias per head, broadcast over seq_len and feature_dim
            self.bias = nn.Parameter(
                torch.zeros(
                    (1, self.num_heads, 1, 1),  # self.feature_dim),
                    dtype=self.dtype,
                    device=self.device,
                )
            )
            nn.init.kaiming_uniform_(self.bias)
        else:
            self.bias = 0.0  # hack: scalar keeps `+ self.bias` in forward valid

    def zero_init_with_skip_(self):
        """
        Initialize weights to zero matrix if skip connection
        (so the module starts as the identity map via the skip path)
        """
        with torch.no_grad():
            nn.init.zeros_(self.layer)

    def zero_init_(self):
        """
        Initialize weights to identity matrix if no skip connection
        """
        with torch.no_grad():
            for i in range(self.layer.shape[0]):
                try:
                    nn.init.eye_(self.layer[i])
                except RuntimeError:
                    # eye_ can refuse certain parameter views; fall back to
                    # building an identity tensor and copying it in
                    with torch.no_grad():
                        dtype = self.layer[i].dtype
                        weight = torch.eye(
                            *self.layer[i].shape,
                            requires_grad=self.layer[i].requires_grad,
                            device=self.layer[i].device,
                        )
                        self.layer[i] = weight.to(dtype=dtype)

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)

        Returns (batch_size, num_heads, seq_len, feature_dim); per-head
        matmul via einsum, plus bias and optional skip connection.
        """
        _x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
        return x + _x if self.skip_connection else _x
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMapAdapter(FeatureMapMLP):
    """
    Learnable Feature map with bottleneck adapter
    as in https://arxiv.org/abs/1902.00751

    We don't use but could be fun to try

    Two per-head projections (head_dim -> hidden_dim -> feature_dim) with a
    ReLU in between; forced to skip_connection + bias + zero_init so the
    module starts as the identity map.
    """

    def __init__(self, hidden_dim: int, *args, **kwargs):
        kwargs["skip_connection"] = True
        kwargs["bias"] = True
        kwargs["zero_init"] = True
        # Must be set before super().__init__, which calls our overridden
        # init_weights_ and reads self.hidden_dim
        self.hidden_dim = hidden_dim
        super().__init__(*args, **kwargs)

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases

        Builds the down-projection (layer0/bias0) and up-projection
        (layer1/bias1) of the bottleneck adapter.
        """
        kwargs = {"dtype": self.dtype, "device": self.device}
        self.layer0 = nn.Parameter(
            torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
        )
        self.layer1 = nn.Parameter(
            torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.layer0)
        nn.init.kaiming_uniform_(self.layer1)

        self.bias0 = nn.Parameter(
            torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
        )
        self.bias1 = nn.Parameter(
            torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.bias0)
        nn.init.kaiming_uniform_(self.bias1)

    def zero_init_with_skip_(self):
        # Zero all adapter parameters so, with the skip connection,
        # the module initially computes the identity
        with torch.no_grad():
            nn.init.zeros_(self.layer0)
            nn.init.zeros_(self.layer1)
            nn.init.zeros_(self.bias0)
            nn.init.zeros_(self.bias1)

    def zero_init_(self):
        # Identity init without a skip connection is undefined for the
        # two-layer bottleneck; constructor always forces skip_connection=True
        raise NotImplementedError

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        -> Down-project, apply nonlinearity, up-project; add skip connection
        """
        _x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
        _x = F.relu(_x)
        _x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
        return x + _x if self.skip_connection else _x
|
|
||||||
@@ -1,460 +0,0 @@
|
|||||||
"""
|
|
||||||
Subquadratic attention combining sliding window and linear attentions
|
|
||||||
- Using "standard" sliding windows
|
|
||||||
- Didactically computes outputs with n^2 attention weights for now
|
|
||||||
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
|
||||||
|
|
||||||
For each layer:
|
|
||||||
- We first compute (softmax) attention over sliding windows
|
|
||||||
- We then compute standard linear attention to "fill in" the earlier parts
|
|
||||||
- We combine to model the entire sequence
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
|
|
||||||
from .linear_attention import (
|
|
||||||
LinearAttentionState,
|
|
||||||
LolcatsLinearAttention,
|
|
||||||
softmax_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------
|
|
||||||
# Sliding window helpers
|
|
||||||
# ----------------------
|
|
||||||
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore

    Args:
        window_size: number of most-recent key positions handled by softmax
            (sliding-window) attention; everything earlier is linear attention.
        q_len: query sequence length.
        k_len: key sequence length.
        device: device on which to allocate the masks.

    Returns:
        (window_mask, linear_mask), each (1, 1, q_len, k_len) int tensors that
        broadcast over (batch, heads, q_len, k_len).
    """
    # Clamp the diagonal offset at 0 so a degenerate q_len > k_len input still
    # yields a valid causal mask instead of shifting the diagonal off-grid.
    diag = max(k_len - q_len, 0)
    causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(diag)
    linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        diag - window_size
    )
    # Window mask = causal positions not already covered by linear attention
    window_mask = causal_mask - linear_mask
    # Return softmax mask (window), linear attention mask
    # -> shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions.

    Args:
        q, k, v: raw queries/keys/values — assumed (batch, heads, seq, head_dim)
            from the einsum subscripts; TODO confirm against callers.
        f_q, f_k: feature-mapped queries/keys for the linear-attention term.
        window_factor, linear_factor: per-head mixing weights for the softmax
            and linear terms (broadcastable over the score tensors).
        window_size: width of the softmax sliding window.
        kv_state, k_state: optional prior linear-attention state from earlier
            chunks; both must be given for the state terms to be added.
        eps: NOTE(review): accepted but unused in this body.
        mask_value: additive fill for masked softmax logits (pre-exp).

    Returns:
        (y, a): outputs in q.dtype and the combined normalized attention
        weights (for the current chunk only).
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention); scores scaled by 1/sqrt(head_dim)
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention): un-normalized phi(q)·phi(k) scores
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine: both terms share a single normalizer (sum_sm + sum_ln)
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        # Numerator contribution of the carried-over linear-attention state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        # Matching denominator contribution so normalization stays consistent
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------
|
|
||||||
# Attention layer class
|
|
||||||
# ---------------------
|
|
||||||
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        # window_size: softmax sliding-window width during training/prefill;
        # decode_window_size: window used at decode time (defaults to window_size)
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        self.attention_type = kwargs["attention_type"]  # 'hedgehog_llama_window_sw'
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        # NOTE(review): duplicate of the assignment above; kept as-is
        self.attention_type = kwargs[
            "attention_type"
        ]  # 'hedgehog_long_llama_window_sw'
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            # Learnable per-head mixing logits (passed through sigmoid in forward)
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            # Fixed (non-trainable) mixing logits, still saved with the module
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(
            k
        )  # Have to do after repeat for grouped-query attn if we use same fmap

        if self.train_attention:
            # Attention-transfer mode: produce (teacher, student) pairs.
            # 1. Compute "ground-truth" attention output and weights
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
            # Pack (student, teacher) weights and outputs for the transfer loss
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms (scores against the window cache)
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms (carried state);
                    # shared normalizer sum_sm + sum_ln as in the quadratic path
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    # Pull prior chunk state if this layer has any yet
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.quadratic_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    # past_key_value.update(k, v.detach(), self.layer_idx,
                    #                       fmap_key_states=f_k.detach(),
                    #                       accumulate_in_fp32=True)
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
        return y_true, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class LinearAttentionSlidingWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)

    Per layer it tracks:
      - kv_states / k_states: linear-attention state over the whole sequence
      - decode_kv_states / decode_k_states: same, but excluding the last
        `window_size` positions (those live in k_cache / v_cache instead)
      - k_cache / v_cache: raw keys/values for the sliding window
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`

        Returns:
            (kv_states[layer_idx], k_states[layer_idx]) after the update.

        Raises:
            ValueError: if `fmap_key_states` or `layer_idx` is None.
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states must not be None")

        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                # Global token counter is advanced once per model step
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # Accumulate outer products in fp32, cast back to dtype on store
                # key_states = key_states.float()
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state (decode state + the trailing window's contribution)
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                # Only the last window_size keys/values are kept verbatim
                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently (add this chunk's terms)
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
            # NOTE(review): assumes _seen_tokens_by_layer[layer_idx] already
            # exists; the init branch above never appends to it — verify how
            # callers pre-populate this list
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding

        Slides the window forward by the new tokens: the evicted (oldest)
        cached key/value is folded into the decode states, and the new
        keys/values are appended to the window cache.

        Returns:
            (k_cache, v_cache, decode_kv_states, decode_k_states) for the layer.
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                # Window not yet full: just append, nothing is evicted
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
                # if k_cache[:, :, :1, :].sum() == 0:   # heuristic for zeroing out padding in cache
                #     f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
                # else:
                #     f_k_state = feature_map_k(k_cache[:, :, :1, :])
                # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
                # Fold the evicted (oldest) cache entry into the decode states
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(
                    dtype
                )  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                # Slide the window: drop the oldest entry, append the new one(s)
                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
|
|
||||||
@@ -1,685 +0,0 @@
|
|||||||
"""
|
|
||||||
Subquadratic attention combining sliding window and linear attentions
|
|
||||||
- Using "standard" sliding windows
|
|
||||||
- Didactically computes outputs with n^2 attention weights for now
|
|
||||||
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
|
||||||
|
|
||||||
For each layer:
|
|
||||||
- We first compute (softmax) attention over sliding windows
|
|
||||||
- We then compute standard linear attention to "fill in" the earlier parts
|
|
||||||
- We combine to model the entire sequence
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
_flash_attention_forward = None # Transformers v4.36
|
|
||||||
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
||||||
|
|
||||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
|
||||||
from .linear_attention import (
|
|
||||||
LinearAttentionState,
|
|
||||||
LolcatsLinearAttention,
|
|
||||||
causal_dot_product,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Module-level logger, namespaced to this module's import path
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------
|
|
||||||
# Sliding window helpers
|
|
||||||
# ----------------------
|
|
||||||
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore

    The softmax (window) mask covers the trailing `window_size` causal
    positions per query; the linear mask covers every earlier causal position.
    """
    # Diagonal offset clamped at 0 so q_len > k_len stays well-formed
    diag_offset = max(k_len - q_len, 0)
    ones = torch.ones((q_len, k_len), device=device, dtype=torch.int)
    causal_mask = ones.tril(diag_offset)
    linear_mask = ones.tril(diag_offset - window_size)
    window_mask = causal_mask - linear_mask
    # Shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions.

    Softmax attention is computed inside the causal sliding-window band and
    feature-map linear attention strictly below it; both terms share a single
    normalizer. Optional ``kv_state`` / ``k_state`` fold in the recurrent
    state of earlier chunks (chunked processing).

    Returns ``(y, a)``: the attention output and the combined attention
    weights for the current chunk (computed before any prior-state terms).
    """
    sm_mask, ln_mask = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)

    # 1. Sliding-window softmax scores (numerically-stable exp; the row max
    #    is subtracted here and normalization happens jointly below)
    scale = k.shape[-1] ** -0.5
    scores_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * scale
    scores_sm = scores_sm.masked_fill(~sm_mask.bool(), mask_value)
    row_max = torch.amax(scores_sm, dim=-1, keepdim=True)
    scores_sm = window_factor * torch.exp(scores_sm - row_max)
    norm_sm = scores_sm.sum(dim=-1, keepdim=True)

    # 2. Linear-attention scores for everything below the window
    scores_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    scores_ln = linear_factor * scores_ln.masked_fill(~ln_mask.bool(), 0)
    norm_ln = scores_ln.sum(dim=-1, keepdim=True)

    # 3. Jointly-normalized weights, saved before any prior-state terms
    a = ((scores_sm + scores_ln) / (norm_sm + norm_ln)).to(q.dtype)
    y = torch.einsum("bhmn,bhnd->bhmd", scores_sm + scores_ln, v.float())
    if kv_state is not None and k_state is not None:
        # Let outputs also attend to the recurrent state of earlier chunks
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        norm_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (norm_sm + norm_ln)).to(q.dtype)
    return y, a  # attention weights cover only the current chunk
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
|
||||||
# Hybrid window attention linear
|
|
||||||
# ------------------------------
|
|
||||||
def under_window_linear_attention(
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_size: int,
    linear_factor: torch.Tensor,
    eps: float = 1e-12,
):
    """
    Compute the under-window linear-attention term with O(q_len) complexity.

    Keys and values are shifted right by ``window_size`` (zero-filled at the
    front) so each query only sees positions strictly below its sliding
    window. Returns the unnormalized output and its normalizer ``sum_qk``.
    """
    out_dtype = f_q.dtype
    shifted_k = F.pad(f_k, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]
    shifted_v = F.pad(v, (0, 0, window_size, 0), value=0)[:, :, :-window_size, :]
    # Causal cumulative (f_q . f_k) @ v in fp32 via the fast-transformers kernel
    qkv = linear_factor * causal_dot_product(
        f_q.contiguous().to(dtype=torch.float32),
        shifted_k.contiguous().to(dtype=torch.float32),
        shifted_v.contiguous().to(dtype=torch.float32),
    ).to(dtype=out_dtype)
    # Normalizer: each query dotted with the running sum of shifted keys
    cum_k = shifted_k.float().cumsum(dim=2).to(dtype=out_dtype)
    sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, cum_k)[..., None]
    sum_qk[sum_qk == 0] += eps  # avoid divide-by-zero on fully-masked rows
    return qkv, sum_qk
|
|
||||||
|
|
||||||
|
|
||||||
def sliding_window_softmax_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    window_size: int,
    window_factor: torch.Tensor,
    mask_value: float = -1e8,
):
    """
    Compute sliding window softmax attention without materializing
    O(seq_len^2) attention weights.

    Keys/values are left-padded and unfolded into a ``window_size``-long
    window per query position, so causality holds by construction. Returns
    the unnormalized output and the per-row normalizer ``sum_sm`` (the
    caller divides after combining with the linear-attention terms).

    NOTE(review): ``mask_value`` is accepted for interface parity but unused;
    padding is masked via the exact-zero heuristic below.
    """
    head_dim = q.shape[-1]
    # Unfold to (b, h, l, d, window_size): the window ending at each query
    k_win = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(
        dimension=2, size=window_size, step=1
    )
    v_win = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(
        dimension=2, size=window_size, step=1
    )

    scores = torch.einsum("bhld,bhldw->bhlw", q, k_win) * (head_dim**-0.5)
    # Exact zeros come from the zero padding above; push them to -inf-ish
    scores[scores == 0] = -torch.finfo(q.dtype).max
    scores_max = torch.amax(scores, dim=-1, keepdim=True)
    scores = window_factor * torch.exp(scores - scores_max)
    sum_sm = scores.sum(dim=-1, keepdim=True)
    return torch.einsum("bhlw,bhldw->bhld", scores, v_win), sum_sm
|
|
||||||
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
|
|
||||||
|
|
||||||
|
|
||||||
def hybrid_attention_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: Optional[torch.Tensor] = None,
    linear_factor: Optional[torch.Tensor] = None,
    window_size: int = 64,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Alternative hybrid attention combining sliding window and linear
    attentions, using O(n) memory in sequence length by padding and
    unfolding windows.

    Returns ``(y, None)``: unlike the quadratic path, attention weights are
    never materialized.

    NOTE(review): ``kv_state`` / ``k_state`` are accepted for interface
    parity with the quadratic variant but are unused here.
    """
    if window_factor is None:
        raise ValueError("window_factor must be provided")
    if linear_factor is None:
        raise ValueError("linear_factor must be provided")

    # Softmax term over the sliding window (no grad through this branch)
    with torch.no_grad():
        qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
            q, k, v, window_size, window_factor, mask_value
        )

    # Linear-attention term for everything below the window
    qkv_ln, sum_qk_ln = under_window_linear_attention(
        f_q, f_k, v, window_size, linear_factor, eps
    )

    # Joint normalization of both terms
    return (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln), None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------
|
|
||||||
# Attention layer class
|
|
||||||
# ---------------------
|
|
||||||
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    Softmax attention covers the last `window_size` positions; feature-map
    linear attention covers everything before the window. The two are mixed
    per-head via a learnable (or fixed) sigmoid `window_factors` gate.
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        # Window attributes must exist before super().__init__ runs
        # (presumably the base class touches them during setup — TODO confirm)
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        # Determine how we compute attentions
        self.linear_attention = hybrid_attention_linear
        self.attention_type = "lolcats_llama_window_sw"
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            # Trainable per-head mixing logits (passed through sigmoid in forward)
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            # Frozen mixing logits, still saved/loaded with the module state
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns a 3-tuple ``(output, attn_weights_or_layer_io, past_key_value)``.
        """
        b, l, _ = hidden_states.size()

        if self.train_attention and self.base_inference:
            # Attention-transfer teacher path: run plain Flash Attention 2
            # (no grad) and hand back layer inputs/outputs for distillation.
            with torch.no_grad():
                _y_true = flash_attention_2(
                    self,  # self.base_attn,
                    hidden_states=hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    use_cache=False,
                )[0]
                # _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
                y_true = _y_true.reshape(b, l, -1).contiguous()
                y_true = self.o_proj(y_true)
                # layer_io = (hidden_states, _y_true)  # hack
                layer_io = (hidden_states.cpu(), _y_true.cpu())  # hack
            return y_true, layer_io, None

        else:
            q, k, v, kv_seq_len = self.process_qkv(
                hidden_states, attention_mask, position_ids, past_key_value
            )
            f_q, f_k = self.feature_map_q(q), self.feature_map_k(
                k
            )  # Have to do after repeat for grouped-query attn if we use same fmap

            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                # Per-head mixing gate in (0, 1); with affine factors the
                # linear term gets the complementary weight.
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.linear_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    # Roll the window cache forward by one token and get the
                    # updated recurrent (KV, K) states for the pre-window part.
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms (single query vs. window cache)
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    # Joint normalization over both terms
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    # Pull the recurrent states from prior chunks if present;
                    # IndexError means this layer has no state yet.
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.linear_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    # past_key_value.update(k, v.detach(), self.layer_idx,
                    #                       fmap_key_states=f_k.detach(),
                    #                       accumulate_in_fp32=True)
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            _y_true = y_true.transpose(1, 2).contiguous()
            y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))

            if self.train_attention:
                attn_weights = _y_true  # flash_attn outputs are shape (b, l, h, d)
        return y_true, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class LinearAttentionSlidingWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        # Per-layer token counts; grown lazily in update() on first sight of a layer
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`

        Raises ValueError if `fmap_key_states` or `layer_idx` is None.
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states must not be None")

        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # Accumulate states in fp32 to limit drift; cast back on store
                # key_states = key_states.float()
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state (adds the in-window terms on top)
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]

            # BUG FIX: `_seen_tokens_by_layer` starts empty and the original
            # never appended on the initializing pass (the append was commented
            # out), so indexing it raised IndexError on first use. Grow the
            # list when a layer is first seen; accumulate afterwards.
            if len(self._seen_tokens_by_layer) <= layer_idx:
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding.

        Slides the window cache forward by one position: the evicted (oldest)
        key/value pair is folded into the recurrent decode states, and the new
        token is appended to the window.

        Returns (k_cache, v_cache, decode_kv_state, decode_k_state) for this layer.
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
                # if k_cache[:, :, :1, :].sum() == 0:  # heuristic for zeroing out padding in cache
                #     f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
                # else:
                #     f_k_state = feature_map_k(k_cache[:, :, :1, :])
                # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
                # Fold the evicted window position into the recurrent states
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(
                    dtype
                )  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------
|
|
||||||
# Flash Attention 2
|
|
||||||
# -----------------
|
|
||||||
|
|
||||||
|
|
||||||
def flash_attention_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
):
    """
    Wrapper for LlamaFlashAttention2
    Copied and modified from HF Transformers v4.36 and v4.43 implementations
    - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
    - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456

    Called as a free function with an attention module as `self`; returns
    (attn_output, past_key_value).
    """

    bsz, q_len, _ = hidden_states.size()

    # Project hidden states into per-head query/key/value tensors
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    # Version shim: rotary_emb's signature changed between Transformers
    # releases, so try the old call first and fall back to the new one.
    try:  # As in Transformers v4.36
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
    except Exception:  # As in Transformers v4.39
        cos, sin = self.rotary_emb(key_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        LOG.debug(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Prefer the bound method (older Transformers attention classes) and fall
    # back to the module-level helper imported at the top of the file.
    # NOTE(review): the module-level `_flash_attention_forward` is None when
    # the import failed (Transformers v4.36) — this branch assumes at least
    # one of the two is available.
    if getattr(self, "_flash_attention_forward", False):
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=True,
        )
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=0,  # dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=True,
        )
    return attn_output, past_key_value
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
"""
|
|
||||||
LoLCATs attention combining sliding window and linear attentions
|
|
||||||
- Using standard sliding window arrangement
|
|
||||||
- Training over long sequences with fixed memory with recurrent view
|
|
||||||
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
|
||||||
|
|
||||||
For each layer:
|
|
||||||
- We first compute (softmax) attention over sliding windows
|
|
||||||
- We then compute standard linear attention to "fill in" the earlier parts
|
|
||||||
- We combine to model the entire sequence
|
|
||||||
"""
|
|
||||||
from .linear_window_attention_sw import hybrid_attention_quadratic
|
|
||||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
|
||||||
|
|
||||||
|
|
||||||
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
    """
    Lolcats attention combining sliding window and linear attention,
    using the standard sliding-window (non-terraced) quadratic kernel.
    """

    def __init__(self, remove_base_attn: bool = True, **kwargs):
        # BUG FIX: the original hard-coded `remove_base_attn=True` in the
        # super() call, silently ignoring the constructor argument. Forward
        # the caller's value so `remove_base_attn=False` actually keeps
        # `self.base_attn` around for Flash Attention inference.
        super().__init__(remove_base_attn=remove_base_attn, **kwargs)
        # Swap in the standard sliding-window quadratic attention
        self.quadratic_attention = hybrid_attention_quadratic
|
|
||||||
@@ -1,466 +0,0 @@
|
|||||||
"""
|
|
||||||
Subquadratic attention combining sliding window and linear attentions
|
|
||||||
- Using the TK "terracing" arrangement
|
|
||||||
|
|
||||||
For each layer:
|
|
||||||
- We first compute (softmax) attention over sliding windows
|
|
||||||
- We then compute standard linear attention to "fill in" the earlier parts
|
|
||||||
- We combine to model the entire sequence
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
|
|
||||||
from .linear_attention import (
|
|
||||||
LinearAttentionState,
|
|
||||||
LolcatsLinearAttention,
|
|
||||||
softmax_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------
|
|
||||||
# Sliding window helpers
|
|
||||||
# ----------------------
|
|
||||||
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms under the TK
    "terracing" arrangement -> 1 is include, 0 is ignore.

    Each query's softmax region is its own window_size block plus the
    previous block; everything earlier goes to linear attention. Shapes are
    (1, 1, q_len, k_len) for broadcasting over batch and heads.
    """
    num_blocks = math.ceil(max(q_len, k_len) / window_size)
    # Block-diagonal "terraces" of window_size x window_size ones
    block = torch.ones((window_size, window_size))
    mask = torch.block_diag(*([block] * num_blocks))
    # Shifting left by one block adds the previous terrace; the only
    # wrap-around lands above the diagonal and is dropped by tril below
    mask = mask + torch.roll(mask, -window_size, -1)
    # Trim to the requested sizes, keeping the trailing rows/columns
    if mask.shape[0] > q_len:
        mask = mask[-q_len:]
    if mask.shape[1] > k_len:
        mask = mask[:, -k_len:]
    mask = mask[None, None, ...]  # b, h, q_len, k_len
    # Softmax mask (window) and linear attention mask (its causal complement)
    softmax_mask = torch.tril(mask).to(device=device, dtype=torch.int)
    linear_mask = torch.tril(1 - mask).to(device=device, dtype=torch.int)
    return softmax_mask, linear_mask
|
|
||||||
|
|
||||||
|
|
||||||
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions
    (TK "terracing" masks from this module's ``get_masks``).

    Softmax attention runs inside the terraced window region, feature-map
    linear attention below it, both normalized jointly. Optional
    ``kv_state`` / ``k_state`` fold in the recurrent state of prior chunks.

    Returns ``(y, a)``: attention output plus the combined weights for the
    current chunk (computed before the prior-state contribution).
    """
    win_mask, lin_mask = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)

    # 1. Terraced-window softmax scores (stable exp; joint normalization below)
    att_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (
        k.shape[-1] ** -0.5
    )
    att_sm = att_sm.masked_fill(~win_mask.bool(), mask_value)
    att_max = torch.amax(att_sm, dim=-1, keepdim=True)
    att_sm = window_factor * torch.exp(att_sm - att_max)
    denom_sm = att_sm.sum(dim=-1, keepdim=True)

    # 2. Linear-attention scores under the window
    att_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    att_ln = linear_factor * att_ln.masked_fill(~lin_mask.bool(), 0)
    denom_ln = att_ln.sum(dim=-1, keepdim=True)

    # 3. Combine under a shared normalizer; weights saved pre-state
    a = ((att_sm + att_ln) / (denom_sm + denom_ln)).to(q.dtype)
    y = torch.einsum("bhmn,bhnd->bhmd", att_sm + att_ln, v.float())
    if kv_state is not None and k_state is not None:
        # Add the contribution from earlier chunks' recurrent state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        denom_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (denom_sm + denom_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------
|
|
||||||
# Attention layer class
|
|
||||||
# ---------------------
|
|
||||||
class LolcatsTKWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention.

    Softmax attention covers the most recent `window_size` tokens while
    linear (feature-map) attention covers all earlier positions; the two
    are mixed per head with learnable, sigmoid-squashed window factors.
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        """
        Args:
            window_size: number of trailing tokens handled by softmax attention.
            decode_window_size: window used at decode time; defaults to `window_size`.
            affine_attention_factors: if True, the linear-attention factor is tied
                to the window factor as `1 - sigmoid(window_factors)`.
            init_window_factor: initial (pre-sigmoid) window-factor value.
            train_window_factor: if True, window factors are a learnable
                `nn.Parameter`; otherwise they are registered as a fixed buffer.
            state_grad_enabled: whether recurrent state updates carry gradients.
        """
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        # e.g. 'hedgehog_llama_window_tk' / 'hedgehog_long_llama_window_tk'
        # (previously assigned twice with the identical value; once suffices)
        self.attention_type = kwargs["attention_type"]
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled
        self.window_factor = self.window_factors  # legacy naming support

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns `(output, attn_weights, past_key_value)`; the content of
        `attn_weights` depends on the branch taken (see inline comments).
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        # Have to do after repeat for grouped-query attn if we use same fmap
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)

        if self.train_attention:
            # 1. Compute "ground-truth" attention output and weights
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
            # Predicted and true weights/outputs for the attention-transfer loss
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms (max-subtracted for stability)
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms and normalize jointly
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhld,bhnd->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    # Prior recurrent states for this layer, if initialized yet
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.quadratic_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
        return y_true, attn_weights, past_key_value
|
|
||||||
|
|
||||||
class LinearAttentionTKWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`

        Returns:
            `(kv_states[layer_idx], k_states[layer_idx])` after the update.

        Raises:
            ValueError: if `fmap_key_states` or `layer_idx` is None.
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states should not be None")

        if layer_idx is None:
            raise ValueError("layer_idx should not be None")

        with torch.set_grad_enabled(grad_enabled):
            # Count total tokens once per forward (layer 0 sees them first)
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # key_states stay in their original dtype; only the terms that
                # accumulate across chunks are promoted to fp32
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state (includes the trailing window terms as well)
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # FIX: initialize the per-layer token count here. Previously this
                # was never populated (the append was commented out), so the
                # `+=` in the update branch below -- and in update_for_decoding --
                # raised IndexError on the second chunk for every layer.
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding.

        Evicts the oldest cached token into the running linear-attention
        states and appends the new token(s) to the sliding-window cache.

        Returns:
            `(k_cache, v_cache, decode_kv_state, decode_k_state)` for the layer.
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # Fold the evicted (oldest) token into the linear-attention states
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(
                    dtype
                )  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                # Slide the window forward by one token
                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
"""
|
|
||||||
LoLCATs + ThunderKittens linear attention + sliding window for generation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Callable, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .linear_attention import LinearAttentionState
|
|
||||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
|
||||||
|
|
||||||
LOG.debug("Successfully imported ThunderKittens for TK window attention")
|
|
||||||
except ImportError:
|
|
||||||
LOG.debug("Failed to import ThunderKittens for TK window attention")
|
|
||||||
|
|
||||||
|
|
||||||
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
    """
    Generation-only variant of LoLCATs sliding-window + linear attention that
    uses the fused ThunderKittens `hedgehog` kernel for the prefill pass and a
    pure-PyTorch recurrent path for single-token decoding.
    """

    def __init__(self, *args, window_size: int = 64, **kwargs):
        # NOTE(review): the `window_size` argument is accepted but ignored;
        # the TK kernel only supports a window of 64, which is forced below.
        super().__init__(*args, **kwargs)
        self.train_attention = False
        self.base_inference = False
        self.window_size = 64  # hard-coded support for TK kernel
        self.decode_window_size = 64

        # Preallocated output/state buffers for the TK kernel.
        # NOTE(review): shapes are hard-coded (batch 1, 32 heads, 8192 tokens,
        # head dim 128) and pinned to CUDA -- presumably sized for one target
        # model/config; verify against the deployment before reuse.
        b, h, l, d = 1, 32, 8192, 128
        self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
        self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
        self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,  # "legacy" cache approach
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Generation-only: asserts a cache is present and that attention-transfer
        modes are off, then dispatches to the decode path (single query token
        against cached window + linear states) or the TK-kernel prefill path.
        """
        b, l, _ = hidden_states.size()
        assert (
            past_key_value is not None
        ), "past_key_value must be provided for generation"
        assert (
            self.train_attention is False
        ), "train_attention is not supported for generation"
        assert (
            self.base_inference is False
        ), "base_inference is not supported for generation"
        assert use_cache is True, "use_cache must be True for generation"
        past_key_value.window_size = self.decode_window_size
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        if q.shape[2] == 1 and kv_seq_len > 1:  # Generating after prefill
            f_q = self.feature_map_q(q)
            _kv = past_key_value.update_for_decoding(
                k, v, self.layer_idx, self.feature_map_k
            )
            k_cache, v_cache, kv_state, k_state = _kv
            # Sliding window + linear attention decode
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1

            # Softmax attention terms (max-subtracted; the shared normalizer
            # below accounts for the subtracted max)
            a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
                k.shape[-1] ** -0.5
            )
            a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
            a_sm = window_factors * torch.exp(a_sm - a_sm_max)
            sum_sm = a_sm.sum(dim=-1, keepdim=True)

            # Combine with linear attention terms
            y_true = torch.einsum(
                "bhmn,bhnd->bhmd", a_sm, v_cache.float()
            ) + linear_factors * torch.einsum(
                "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
            )
            sum_ln = (
                linear_factors
                * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
                    ..., None
                ]
            )
            # NOTE(review): this rebinds self.y_true to the (b, h, 1, d) decode
            # result, replacing the preallocated prefill buffer -- confirm the
            # final view() below is the intended consumer of both shapes.
            self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

        else:  # Process prefill
            # Use TK-implemented linear + terrace window attention
            b, h, l, d = q.shape
            device = q.device
            # tk.hedgehog arguments: per-head mixing factors in fp32
            betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
            alphas = (
                1 - betas
                if self.affine_attention_factors
                else torch.ones(betas.shape, dtype=torch.float32, device=device)
            )
            q_map = self.feature_map_q.mlp.layer
            k_map = self.feature_map_k.mlp.layer
            # Saves outputs to y_pred, k_state, kv_state, where we fuse:
            # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
            # 2. y_pred = attention(q, k, f_q, f_k, v)  # b, h, l, d
            # 3. kv_state = torch.einsum('bhlf,bhld->bhfd',
            #        f_k[:, :, :-self.window_size],
            #        v[:, :, :-self.window_size])  # b, h, f, d
            # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2)  # b, h, d

            tk_window_hedgehog_attention(
                q.contiguous(),
                k.contiguous(),
                v.contiguous(),
                self.y_true,
                self.k_state,
                self.kv_state,
                q_map,
                k_map,
                alphas,
                betas,
            )

            past_key_value.update_with_kv(
                self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
            )

        # Concatenate heads and apply output projection
        y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
        y_true = self.o_proj(y_true)
        return y_true, None, past_key_value
|
|
||||||
|
|
||||||
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.window_size = window_size

        # Linear-attention states accumulated over everything outside the window
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        # Standard KV cache over the trailing `window_size` tokens
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []

    def update_with_kv(
        self,
        kv_state: torch.Tensor,
        k_state: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_idx: int,
    ):
        """
        Update the cache with new KV and K states

        Called once per layer after prefill: stores the kernel-computed linear
        states and the last-window keys/values for this layer.
        """
        if layer_idx == 0:
            self._seen_tokens += k.shape[2]
        self._seen_tokens_by_layer.append(k.shape[2])

        # Initialize KV and K states
        if len(self.decode_k_states) <= layer_idx:
            self.decode_kv_states.append(kv_state)
            self.decode_k_states.append(k_state)
        else:  # Update KV and K states
            self.decode_kv_states[layer_idx] = (
                self.decode_kv_states[layer_idx] + kv_state
            )
            self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state

        # NOTE(review): the window caches are appended unconditionally, unlike
        # the states above -- presumably update_with_kv runs exactly once per
        # layer per generation; a second prefill would misalign layer indexing.
        self.k_cache.append(k[:, :, -self.window_size :, :])
        self.v_cache.append(v[:, :, -self.window_size :, :])

    def update_for_decoding(
        self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
    ):
        """
        Update the cache for decoding

        Evicts the oldest cached token into the linear-attention states and
        slides the window forward by the new token(s).

        Returns:
            `(k_cache, v_cache, decode_kv_state, decode_k_state)` for the layer.
        """
        k_cache = self.k_cache[layer_idx]
        v_cache = self.v_cache[layer_idx]
        # Fold the evicted (oldest) cached token into the running linear states
        k_state = feature_map_k(k_cache[:, :, :1, :])
        v_state = v_cache[:, :, :1, :]
        kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
            k.dtype
        )

        self.decode_kv_states[layer_idx] += kv_state
        self.decode_k_states[layer_idx] += k_state

        self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
        self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
        if layer_idx == 0:
            self._seen_tokens += k.shape[-2]
        self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
        return (
            self.k_cache[layer_idx],
            self.v_cache[layer_idx],
            self.decode_kv_states[layer_idx],
            self.decode_k_states[layer_idx],
        )
|
|
||||||
@@ -1,306 +0,0 @@
|
|||||||
"""
|
|
||||||
LoLCATs attention combining sliding window and linear attentions
|
|
||||||
- Using the TK "terracing" arrangement
|
|
||||||
- Training over long sequences with fixed memory with recurrent view
|
|
||||||
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
|
||||||
|
|
||||||
For each layer:
|
|
||||||
- We first compute (softmax) attention over sliding windows
|
|
||||||
- We then compute standard linear attention to "fill in" the earlier parts
|
|
||||||
- We combine to model the entire sequence
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
_flash_attention_forward = None # Transformers v4.36
|
|
||||||
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
||||||
|
|
||||||
from .linear_attention import softmax_attention
|
|
||||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
|
||||||
|
|
||||||
LOG = logging.getLogger(
|
|
||||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
    """
    Lolcats attention combining sliding window and linear attention

    Long-sequence variant: supports chunked/stateful training with a recurrent
    state cache and, during attention transfer, Flash Attention 2 inference.
    """

    def __init__(self, remove_base_attn=True, **kwargs):
        # keep self.base_attn for Flash Attention inference
        # NOTE(review): the `remove_base_attn` parameter is ignored; True is
        # always forwarded to the parent -- confirm this is intentional.
        super().__init__(remove_base_attn=True, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns `(output, attn_weights, past_key_value)`; `attn_weights` content
        varies by branch (see inline comments).
        """
        b, l, _ = hidden_states.size()
        if self.train_attention and self.base_inference:
            # Attention-transfer data collection: run true softmax attention
            # via Flash Attention 2 and return the layer inputs/outputs.
            with torch.no_grad():
                # LOG.debug(hidden_states.shape)
                _y_true = flash_attention_2(
                    self,  # self.base_attn,
                    hidden_states=hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    # output_hidden_states=False,
                    use_cache=False,
                )[0]
                # _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
                y_true = _y_true.reshape(b, l, -1).contiguous()
                y_true = self.o_proj(y_true)
                layer_io = (hidden_states, _y_true)  # hack
                # layer_io = (hidden_states.cpu(), _y_true.cpu())  # hack
                return y_true, layer_io, None

        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)

        # attention_mask = None  # For now this is always True
        if past_key_value is None:  # Regular training
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
        else:
            past_key_value.window_size = self.decode_window_size
            if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training:  # Generating
                # NOTE(review): `a_pred` is not assigned on this path; if
                # self.train_attention were True while generating, the tuple
                # construction below would raise NameError -- confirm the two
                # flags are mutually exclusive in practice.
                assert use_cache is True
                _kv = past_key_value.update_for_decoding(
                    k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                )
                k_cache, v_cache, f_kv_state, f_k_state = _kv

                # Sliding window + linear attention decode
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )

                a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
                    k.shape[-1] ** -0.5
                )
                # a_sm = torch.softmax(a_sm, dim=-1)
                # Max-subtracted exponentials; the joint normalizer below
                # accounts for the subtracted max.
                a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                sum_sm = a_sm.sum(dim=-1, keepdim=True)

                y_pred = torch.einsum(
                    "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                ) + linear_factors * torch.einsum(
                    "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                )
                sum_ln = (
                    linear_factors
                    * torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
                        ..., None
                    ]
                )
                y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)

            else:  # Stateful training
                if (
                    self.state_grad_enabled
                    and self.layer_idx == 0
                    and position_ids is not None
                ):
                    LOG.debug(
                        f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
                    )
                    LOG.debug(
                        f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
                    )
                # Prior recurrent states for this layer, if initialized yet
                try:
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                except IndexError:
                    kv_state, k_state = None, None
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_pred, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                    kv_state=kv_state,
                    k_state=k_state,
                )
                # Save and update KV cache and states
                # past_key_value.update(k, v.detach(), self.layer_idx,
                #                       fmap_key_states=f_k.detach(),
                #                       accumulate_in_fp32=True)
                past_key_value.update(
                    k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
                )

        # Concatenate heads and apply output projection
        _y_pred = y_pred.transpose(1, 2).contiguous()
        y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))

        if self.train_attention:
            with torch.no_grad():
                a_true = softmax_attention(q, k, None, causal=True)[1]
            attn_weights = (_y_pred, (a_pred, a_true))
        else:
            attn_weights = _y_pred  # flash_attn outputs are shape (b, l, h, d)
        return y_pred, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------
|
|
||||||
# Flash Attention 2
|
|
||||||
# -----------------
|
|
||||||
|
|
||||||
|
|
||||||
def flash_attention_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
):
    """
    Wrapper for LlamaFlashAttention2
    Copied and modified from HF Transformers v4.36 and v4.43 implementations
    - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
    - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456

    Projects hidden_states into per-head q/k/v, applies rotary position
    embeddings (supporting both the pre- and post-v4.39 rotary call
    signatures), updates the KV cache when one is provided, and dispatches to
    flash attention.

    Returns:
        (attn_output, past_key_value)
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    # Rotary embedding API changed across Transformers versions; try the old
    # (seq_len-based) call first and fall back to the position_ids-based one.
    try:  # As in Transformers v4.36
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
    except Exception:  # As in Transformers v4.39
        cos, sin = self.rotary_emb(key_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        LOG.debug(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Older Transformers exposed flash attention as a bound method on the
    # attention module; newer versions (v4.43+) provide a module-level helper.
    if getattr(self, "_flash_attention_forward", False):
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=True,
        )
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=0,  # dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=True,
        )
    return attn_output, past_key_value
|
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
"""Linear LLaMA model implementation."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from torch import nn
|
|
||||||
from tqdm import tqdm
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaDecoderLayer,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
LlamaModel,
|
|
||||||
LlamaRMSNorm,
|
|
||||||
LlamaRotaryEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .configuration_linear_llama import LinearLlamaConfig
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
    """
    Llama decoder layer whose stock softmax attention is swapped out for the
    linear attention module described by the model config.
    """

    def __init__(self, config: LinearLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # The parent constructor builds a standard attention module; overwrite
        # it with the linearized variant selected by config.attention_config.
        attn_cfg = config.attention_config
        self.self_attn = convert_llama_attention(layer=self, attention_config=attn_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
class LinearLlamaModel(LlamaModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]

    Args:
        config: LinearLlamaConfig
    """

    config_class = LinearLlamaConfig
    base_model_prefix = "linear_llama"

    def __init__(self, config: LinearLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.gradient_checkpointing = False

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        decoder_layers = [
            LinearLlamaDecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()
|
|
||||||
|
|
||||||
|
|
||||||
class LinearLlamaForCausalLM(LlamaForCausalLM):
    """
    Causal-LM head on top of the linear-attention Llama decoder stack.
    """

    config_class = LinearLlamaConfig
    base_model_prefix = "linear_llama"

    def __init__(self, config):
        super().__init__(config)
        self.model = LinearLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @classmethod
    def from_llama(
        cls,
        model: LlamaForCausalLM,
        config: LinearLlamaConfig,
        train_attention: bool = False,
        remove_base_attn: bool = True,
    ) -> "LinearLlamaForCausalLM":
        """
        Build a LinearLlamaForCausalLM by converting an existing LlamaForCausalLM.

        The freshly initialized decoder and LM head are discarded and replaced
        with the converted modules taken from ``model``.
        """
        if config is None:
            raise ValueError("Missing config")

        # initialize a new model with config
        converted = cls(config=config)

        # remove the default model and lm_head
        del converted.model
        del converted.lm_head

        # load converted model, lm_head, and vocab_size from llama model
        converted.model = convert_attention(
            model.model,
            attention_config=config.attention_config,
            train_attention=train_attention,
            remove_base_attn=remove_base_attn,
        )
        converted.lm_head = model.lm_head
        converted.vocab_size = model.vocab_size
        return converted

    def toggle_attention(self, train: bool = True):
        """
        Toggle attention to be trainable or not
        """
        toggle_attention(self.model, train=train)

    def remove_base_attention(self):
        """
        Remove base attention after distillation
        """
        remove_base_attention(self.model)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_attention(
    model: nn.Module,
    attention_config: dict,
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Swap every attention layer in ``model`` for the type named in ``attention_config``.

    Layer indices listed under ``attention_config["softmax_attentions"]`` keep
    their original softmax attention and are frozen instead of converted.
    """
    keep_softmax = attention_config.get("softmax_attentions", [])
    attention_type = attention_config.get("attention_type")

    if attention_type == "softmax":
        # Nothing to convert; the model keeps its softmax attentions.
        LOG.info(
            f"-> attention_config.attention_type is {attention_type}; not converting attentions"
        )
        return model

    all_layers = traverse_layers(model)
    for idx, decoder_layer in enumerate(
        tqdm(all_layers, desc="Converting attentions...")
    ):
        if idx in keep_softmax:
            # Freeze any preserved softmax attention layers
            for param in decoder_layer.parameters():
                param.requires_grad = False
        else:
            decoder_layer.self_attn = convert_llama_attention(
                decoder_layer,
                attention_config,
                all_layers,
                train_attention,
                remove_base_attn,
            )
            decoder_layer.self_attn.converted = True
    return model
|
|
||||||
|
|
||||||
|
|
||||||
def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Set the ``train_attention`` flag on every attention layer.

    -> Set train_attention = False when finetuning
    """
    for decoder_layer in traverse_layers(llama_model):
        decoder_layer.self_attn.train_attention = train
    return llama_model
|
|
||||||
|
|
||||||
|
|
||||||
def remove_base_attention(llama_model: nn.Module):
    """
    Delete the retained teacher ("base") attention modules after distillation.
    """
    for decoder_layer in traverse_layers(llama_model):
        if getattr(decoder_layer.self_attn, "base_attn", False):
            del decoder_layer.self_attn.base_attn
    return llama_model
|
|
||||||
|
|
||||||
|
|
||||||
def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Locate and return the list of decoder layers for several wrapper shapes.

    Checks, in order: ``model.model.layers`` (CausalLM wrapper),
    ``model.layers`` (bare base model), and
    ``model.base_model.model.model.layers`` (PEFT wrapper).
    """
    try:
        found = model.model.layers
        if verbose:
            LOG.info("-> Loading from model.model.layers")
    except AttributeError as err:  # if base model
        if verbose:
            LOG.info(err)
        try:
            found = model.layers
            if verbose:
                LOG.info("-> Loading from model.layers")
        except AttributeError as inner_err:  # If we make a PEFT model
            if verbose:
                LOG.info(inner_err)
            found = model.base_model.model.model.layers
            if verbose:
                LOG.info("-> Loading from model.base_model.model.model.layers")
    return found
|
|
||||||
|
|
||||||
|
|
||||||
def convert_llama_attention(
    layer: nn.Module,
    attention_config: dict,
    layers: Optional[list[nn.Module]] = None,  # list of layers
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Build the replacement attention module for a single decoder layer.

    The concrete attention class is resolved from ``attention_config`` via
    ``get_attention`` and instantiated around the layer's existing attention.
    """
    attn_factory = get_attention(**attention_config)
    base_attn = layer.self_attn
    last_idx = len(layers) - 1 if layers else None
    return attn_factory(
        base_attn=base_attn,
        layer_idx=base_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=last_idx,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def get_attention(attention_type: str, **kwargs):
    """
    Resolve the linear attention class for ``attention_type`` and return it as
    a ``functools.partial`` pre-bound with ``kwargs``.

    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'
    """
    kwargs["attention_type"] = attention_type

    # Imports stay lazy (inside the branches) so optional backends are only
    # loaded when their attention type is actually requested.
    if attention_type == "lolcats_llama":
        from .linear_attention import LolcatsLinearAttention as attn_cls
    elif attention_type == "lolcats_llama_window_tk":
        from .linear_window_attention_tk import LolcatsTKWindowAttention as attn_cls
    elif attention_type == "lolcats_llama_window_sw":
        from .linear_window_attention_sw import (
            LolcatsSlidingWindowAttention as attn_cls,
        )
    elif attention_type == "lolcats_llama_window_sw_linear":
        from .linear_window_attention_sw_linear import (
            LolcatsLinearSlidingWindowAttention as attn_cls,
        )
    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
        from .linear_window_attention_tk_long import (
            LolcatsTKWindowLongAttention as attn_cls,
        )
    elif attention_type == "lolcats_long_llama_window_sw":
        from .linear_window_attention_sw_long import (
            LolcatsSlidingWindowLongAttention as attn_cls,
        )
    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .linear_window_attention_tk_gen import (
            LolcatsWindowAttentionTKGen as attn_cls,
        )
    else:
        LOG.info(f"-> attention_type {attention_type} not handled... returning None")
        return None

    return partial(attn_cls, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating.

    Softmax-style attentions (and an unset ``attention_type``) reuse the
    ``past_key_values`` object passed in; linear/windowed attentions get a
    freshly constructed specialized cache. Imports are lazy so only the cache
    for the requested attention type is loaded.
    """
    if attention_type is None:
        return past_key_values

    # TK generation build (requires Thunderkittens)
    if "lolcats_llama_window_tk_gen" in attention_type:
        from .linear_window_attention_tk_gen import (
            LinearAttentionTKWindowGenerationCache,
        )

        return LinearAttentionTKWindowGenerationCache()

    if "llama_window_tk" in attention_type:
        from .linear_window_attention_tk import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    # NOTE: this branch also covers "llama_window_sw_linear" — the original
    # had a separate branch for it, unreachable because this substring check
    # matched first; both returned the same sliding-window cache. A dead
    # exact-match branch for "lolcats_llama_window_tk_gen" (already handled
    # above) was removed as well. Behavior is unchanged.
    if "llama_window_sw" in attention_type:
        from .linear_window_attention_sw import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    if "softmax" in attention_type:
        return past_key_values

    from .linear_attention import LinearAttentionState

    return LinearAttentionState()
|
|
||||||
|
|
||||||
|
|
||||||
def register_linear_llama():
    """
    Register Linear LLaMA model with the Transformers library.

    Makes "linear_llama" resolvable through the AutoConfig/AutoModel/
    AutoModelForCausalLM factories, then registers each class for auto-class
    export so checkpoints can be reloaded with the Auto* API.
    """

    # Imported lazily so merely importing this module does not touch the
    # global Transformers registries.
    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

    AutoConfig.register("linear_llama", LinearLlamaConfig)
    AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
    AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)

    # registering for auto classes to save files
    LinearLlamaConfig.register_for_auto_class("AutoConfig")
    LinearLlamaModel.register_for_auto_class("AutoModel")
    LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|
||||||
@@ -1,118 +0,0 @@
|
|||||||
"""
|
|
||||||
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
|
|
||||||
|
|
||||||
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from torch import Tensor, nn, tensor
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
|
|
||||||
|
|
||||||
class DistillAttentionXentMSETrainer(AxolotlTrainer):
    """
    Custom trainer class for distilling attentions.
    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
    """

    def __init__(
        self,
        model: nn.Module,
        mse_factor: float = 1e3,
        xent_factor: float = 0,
        **kwargs: Any,
    ):
        super().__init__(model=model, **kwargs)
        self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
        self.criterion_mse = nn.MSELoss(reduction="mean")
        # Weights applied to the per-layer MSE(outputs) / CrossEntropy(weights)
        # terms; a factor of 0 disables that term entirely.
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor

        self.model_accepts_loss_kwargs = False  # added to combat explosive loss

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Tensor],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> tuple[Tensor, dict]:
        """
        Attention distillation ("attention transfer")
        - For each layer and head, get attentions and train to
          minimize some combo of MSE and cross-entropy loss

        Returns the scalar loss, or ``(loss, metrics_dict)`` when
        ``return_outputs`` is true.
        """
        # alias inputs to data
        data = inputs

        device = model.device

        # Filter out labels
        inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}

        # set num_items_in_batch
        if self.model_accepts_loss_kwargs:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **loss_kwargs}

        # Forward pass
        outputs = model(**inputs, output_attentions=True, use_cache=False)
        outputs = outputs.get("attentions")

        # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
        # n_layers x (predicted_attns, true_attns)
        # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
        loss_mse = tensor(0.0, device=device)
        loss_xent = tensor(0.0, device=device)
        n_layers = 0  # Number of layers to distill
        softmax_layers = []  # layers left as softmax (no distillation signal)
        for layer_idx, attns in enumerate(outputs):
            if attns is None:
                softmax_layers.append(layer_idx)
                continue
            if len(attns) != 2:
                # Unexpected shape — move off-GPU and skip this layer.
                attns = attns.cpu()
                continue
            if self.xent_factor > 0:
                # Cross-entropy loss
                # NOTE(review): the (a_pred, a_true) pair is unpacked from
                # attns[0] while the MSE pair comes from attns[1] — confirm
                # this matches the attention module's attn_weights layout.
                a_pred, a_true = attns[0]
                a_pred = a_pred.clamp(
                    min=1e-12
                ).log()  # nn.CrossEntropy assumes unnormalized logits
                k_len = a_true.shape[-1]  # batch, n_heads, q_len, k_len
                # Compute mean cross-entropy over all queries
                a_pred = a_pred.contiguous().view(-1, k_len)
                a_true = a_true.contiguous().view(-1, k_len)
                loss_xent += self.criterion_xent(a_pred, a_true)
            if self.mse_factor > 0:
                loss_mse += self.criterion_mse(*attns[1])
            n_layers += 1
        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        loss = loss_xent + loss_mse

        # Detach metrics from the graph: report plain Python floats, not live
        # tensors (the original stored loss_mse as a tensor, inconsistently
        # with loss_xent.item(), keeping the autograd graph alive).
        outputs = {
            "loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
            "loss_mse": loss_mse.item() if self.mse_factor > 0 else 0,
            "mse_factor": self.mse_factor,
            "xent_factor": self.xent_factor,
        }
        if "position_ids" in data:
            outputs["input_len"] = data["position_ids"].shape[1]
            outputs["position_ids"] = data["position_ids"][0].detach().cpu().numpy()
        return (loss, outputs) if return_outputs else loss
|
|
||||||
@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
|||||||
if len(strategy.split(".")) == 1:
|
if len(strategy.split(".")) == 1:
|
||||||
strategy = strategy + ".default"
|
strategy = strategy + ".default"
|
||||||
load_fn = strategy.split(".")[-1]
|
load_fn = strategy.split(".")[-1]
|
||||||
strategy = ".".join(strategy.split(".")[:-1])
|
if len(strategy.split(".")) > 1:
|
||||||
mod = importlib.import_module(f".{strategy}", module_base)
|
try:
|
||||||
|
importlib.import_module(
|
||||||
|
strategy.split(".")[-2],
|
||||||
|
".".join(strategy.split(".")[:-2]),
|
||||||
|
)
|
||||||
|
module_base = ".".join(strategy.split(".")[:-2])
|
||||||
|
strategy = strategy.split(".")[-2]
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
else:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
mod = importlib.import_module(strategy, module_base)
|
||||||
func = getattr(mod, load_fn)
|
func = getattr(mod, load_fn)
|
||||||
return func(cfg, **kwargs)
|
return func(cfg, **kwargs)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
||||||
)
|
)
|
||||||
|
|
||||||
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
||||||
@@ -70,7 +70,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
||||||
)
|
)
|
||||||
|
|
||||||
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
||||||
|
|||||||
14
src/axolotl/prompt_strategies/dpo/passthrough.py
Normal file
14
src/axolotl/prompt_strategies/dpo/passthrough.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
DPO prompt strategies passthrough/zero-processing strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """Return an identity transform: the dataset is already in DPO format."""

    def transform_fn(sample, tokenizer=None):  # pylint: disable=possibly-unused-variable,unused-argument
        """Pass the sample through unchanged."""
        return sample

    return transform_fn
|
||||||
@@ -24,6 +24,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
|
from .trl import TrlConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
@@ -33,6 +35,7 @@ class RLType(str, Enum):
|
|||||||
"""RL trainer type configuration subset"""
|
"""RL trainer type configuration subset"""
|
||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
|
grpo = "grpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
@@ -663,14 +666,20 @@ class AxolotlInputConfig(
|
|||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||||
mean_resizing_embeddings: Optional[bool] = False
|
mean_resizing_embeddings: Optional[bool] = False
|
||||||
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
|
shrink_embeddings: Optional[bool] = None
|
||||||
|
|
||||||
rl: Optional[RLType] = None
|
rl: Optional[RLType] = None
|
||||||
|
trl: Optional[TrlConfig] = Field(
|
||||||
|
default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
|
||||||
|
)
|
||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
process_reward_model: Optional[bool] = None
|
process_reward_model: Optional[bool] = None
|
||||||
num_labels: Optional[int] = None
|
num_labels: Optional[int] = None
|
||||||
dpo_use_weighting: Optional[
|
dpo_use_weighting: Optional[
|
||||||
bool
|
bool
|
||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
|
dpo_use_logits_to_keep: Optional[bool] = None
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||||
|
|||||||
32
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
32
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TrlConfig(BaseModel):
    """
    Input args for TRL.
    """

    # KL penalty coefficient (trainer's own default is used when None)
    beta: Optional[float] = None
    max_completion_length: Optional[int] = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum length of the completion for RL training"
        },
    )

    # GRPO specific args
    # Generate completions with vLLM instead of the training model
    use_vllm: Optional[bool] = False
    # Device vLLM runs on; "auto" lets the trainer choose
    vllm_device: Optional[str] = "auto"
    # Fraction of GPU memory vLLM may reserve
    vllm_gpu_memory_utilization: Optional[float] = 0.9
    # Max context length for vLLM (model default when None)
    vllm_max_model_len: Optional[int] = None
    # Weight/activation dtype for vLLM
    vllm_dtype: Optional[str] = "auto"
    # Names of reward functions for GRPO — presumably resolved by the trainer
    # builder; TODO confirm against the loader
    reward_funcs: Optional[List[str]] = None
    # Number of sampled generations per prompt
    num_generations: Optional[int] = None
    # Periodically sync the reference model toward the policy
    sync_ref_model: Optional[bool] = False
    # Mixup coefficient used during reference-model sync
    ref_model_mixup_alpha: Optional[float] = 0.9
    # Steps between reference-model syncs
    ref_model_sync_steps: Optional[int] = 64
|
||||||
@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
|
|
||||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||||
sig = inspect.signature(ds_transform_fn)
|
sig = inspect.signature(ds_transform_fn)
|
||||||
if "tokenizer" in sig.parameters:
|
if "tokenizer" in sig.parameters:
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
|||||||
data_set = data_set.map(
|
data_set = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_set
|
return data_set
|
||||||
@@ -150,36 +151,45 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
else:
|
else:
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
|
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
elif _cfg.rl == "kto":
|
elif _cfg.rl == "kto":
|
||||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
# "prompt", "chosen" and "rejected" already preprocessed
|
||||||
split_datasets[i] = data_set
|
split_datasets[i] = data_set
|
||||||
|
|
||||||
drop_long = partial(
|
if not cfg.skip_prepare_dataset:
|
||||||
drop_long_rl_seq,
|
drop_long = partial(
|
||||||
rl=_cfg.rl,
|
drop_long_rl_seq,
|
||||||
tokenizer=tokenizer,
|
rl=_cfg.rl,
|
||||||
sequence_len=cfg.sequence_len,
|
tokenizer=tokenizer,
|
||||||
)
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Dropping Long Sequences",
|
desc="Dropping Long Sequences",
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(split_datasets[i])
|
dropped = prior_len - len(split_datasets[i])
|
||||||
if dropped:
|
if dropped:
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
|
)
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
combined_datasets = concatenate_datasets(split_datasets)
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
|||||||
from axolotl.utils.data.shared import load_dataset_w_config
|
from axolotl.utils.data.shared import load_dataset_w_config
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
|
drop_long_seq_in_dataset,
|
||||||
md5,
|
md5,
|
||||||
retry_on_request_exceptions,
|
retry_on_request_exceptions,
|
||||||
)
|
)
|
||||||
@@ -56,7 +57,7 @@ from axolotl.utils.trainer import (
|
|||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
@@ -339,8 +340,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
else:
|
else:
|
||||||
LOG.debug("NOT shuffling merged datasets")
|
LOG.debug("NOT shuffling merged datasets")
|
||||||
|
|
||||||
if cfg.sample_packing and not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset = drop_long_seq_in_dataset(dataset, cfg)
|
||||||
|
|
||||||
|
if cfg.sample_packing:
|
||||||
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""data handling helpers"""
|
"""data handling helpers"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
@@ -6,10 +7,15 @@ import time
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from datasets import Dataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import drop_long_seq
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RetryStrategy(Enum):
|
class RetryStrategy(Enum):
|
||||||
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return train_dataset, eval_dataset, dataset
|
return train_dataset, eval_dataset, dataset
|
||||||
|
|
||||||
|
|
||||||
|
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||||
|
if "input_ids" not in dataset.column_names:
|
||||||
|
LOG.warning(
|
||||||
|
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
|
||||||
|
)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
drop_long = functools.partial(
|
||||||
|
drop_long_seq,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||||
|
LOG.debug(f"min_input_len: {min_input_len}")
|
||||||
|
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||||
|
LOG.debug(f"max_input_len: {max_input_len}")
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
prior_len = len(dataset)
|
||||||
|
except TypeError:
|
||||||
|
# handle iterable datasets case
|
||||||
|
prior_len = None
|
||||||
|
|
||||||
|
filter_map_kwargs = {}
|
||||||
|
if not isinstance(dataset, IterableDataset):
|
||||||
|
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||||
|
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||||
|
|
||||||
|
drop_long_kwargs = {}
|
||||||
|
if filter_map_kwargs:
|
||||||
|
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||||
|
|
||||||
|
dataset = dataset.filter(
|
||||||
|
drop_long,
|
||||||
|
batched=True,
|
||||||
|
**filter_map_kwargs,
|
||||||
|
**drop_long_kwargs,
|
||||||
|
)
|
||||||
|
if prior_len:
|
||||||
|
dropped = prior_len - len(dataset)
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|||||||
@@ -1053,9 +1053,12 @@ class ModelLoader:
|
|||||||
if self.cfg.resize_token_embeddings_to_32x
|
if self.cfg.resize_token_embeddings_to_32x
|
||||||
else len(self.tokenizer)
|
else len(self.tokenizer)
|
||||||
)
|
)
|
||||||
if (
|
if hasattr(self.model, "get_input_embeddings") and (
|
||||||
hasattr(self.model, "get_input_embeddings")
|
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
or (
|
||||||
|
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||||
|
and self.cfg.shrink_embeddings
|
||||||
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
|
|||||||
@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
|
|||||||
else:
|
else:
|
||||||
input_ids = dataset.data.column("input_ids")
|
input_ids = dataset.data.column("input_ids")
|
||||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||||
return lengths
|
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
"""
|
"""
|
||||||
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
# Edge case: if input_ids is empty
|
# Edge case: if input_ids is empty
|
||||||
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
drop_long = partial(
|
|
||||||
drop_long_seq,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len or 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type == "mamba":
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||||
|
|
||||||
filter_map_kwargs = {}
|
|
||||||
if not isinstance(train_dataset, IterableDataset):
|
|
||||||
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
|
||||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
|
||||||
|
|
||||||
try:
|
|
||||||
prior_len = len(train_dataset)
|
|
||||||
except TypeError:
|
|
||||||
# handle iterable datasets case
|
|
||||||
prior_len = None
|
|
||||||
drop_long_kwargs = {}
|
|
||||||
if filter_map_kwargs:
|
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
|
||||||
train_dataset = train_dataset.filter(
|
|
||||||
drop_long,
|
|
||||||
batched=True,
|
|
||||||
**filter_map_kwargs,
|
|
||||||
**drop_long_kwargs,
|
|
||||||
)
|
|
||||||
if prior_len:
|
|
||||||
dropped = prior_len - len(train_dataset)
|
|
||||||
if dropped:
|
|
||||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
|
||||||
|
|
||||||
if eval_dataset:
|
|
||||||
try:
|
|
||||||
prior_len = len(eval_dataset)
|
|
||||||
except TypeError:
|
|
||||||
# handle iterable datasets case
|
|
||||||
prior_len = None
|
|
||||||
eval_dataset = eval_dataset.filter(
|
|
||||||
drop_long,
|
|
||||||
**filter_map_kwargs,
|
|
||||||
**drop_long_kwargs,
|
|
||||||
)
|
|
||||||
if prior_len:
|
|
||||||
dropped = prior_len - len(eval_dataset)
|
|
||||||
if dropped:
|
|
||||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
|
||||||
|
|
||||||
def drop_no_trainable_tokens(sample):
|
def drop_no_trainable_tokens(sample):
|
||||||
"""
|
"""
|
||||||
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
||||||
@@ -325,6 +274,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
# handle iterable datasets case
|
# handle iterable datasets case
|
||||||
prior_len = None
|
prior_len = None
|
||||||
|
filter_map_kwargs = {}
|
||||||
|
if not isinstance(train_dataset, IterableDataset):
|
||||||
|
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||||
|
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
||||||
@@ -622,7 +576,7 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
|
if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"):
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
|||||||
"num_labels": 1,
|
"num_labels": 1,
|
||||||
"chat_template": "alpaca",
|
"chat_template": "alpaca",
|
||||||
"reward_model": True,
|
"reward_model": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 2048,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
|
|||||||
Reference in New Issue
Block a user