Feat: Add landmark attention

NanoCode012
2023-06-09 12:54:08 +09:00
parent febe902517
commit 55b8542de8
4 changed files with 1635 additions and 7 deletions


@@ -416,6 +416,8 @@ flash_attention: # require a100 for llama
 # whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
+# Landmark attention (only llama)
+landmark_attention:
 # resume from a specific checkpoint dir
 resume_from_checkpoint:

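For orientation: as the next file's diff shows, these attention toggles are checked in an elif chain inside load_model, so only one patch is applied per run and landmark_attention is considered last. A minimal, dependency-free sketch of that precedence (a plain dict stands in for axolotl's cfg object, and the extra device/inference checks on the flash path are ignored):

def pick_attention_patch(cfg, is_llama_derived_model=True):
    # Mirrors the elif chain in load_model: the first enabled toggle wins.
    if not is_llama_derived_model:
        return None
    if cfg.get("flash_attention"):
        return "flash"
    if cfg.get("xformers_attention"):
        return "xformers"
    if cfg.get("sdp_attention"):
        return "sdp"
    if cfg.get("landmark_attention"):
        return "landmark"
    return None

print(pick_attention_patch({"landmark_attention": True}))  # -> landmark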
File diff suppressed because it is too large

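The suppressed diff is presumably the new llama_landmark_attn monkeypatch module referenced in the other files. As background, landmark attention (Mohtashami & Jaggi, 2023) inserts a special landmark/memory token after every block of ordinary tokens and uses that token's attention score to gate attention to its block, so long contexts can be attended block-by-block. A toy, single-query sketch of the grouped-softmax idea follows; it is not the patched code, it ignores causal masking and the query's own unfinished block, and it treats the last key of each block as that block's landmark:

import math
import torch

def grouped_softmax_attention(q, k, v, block_size):
    # q: (d,), k and v: (n, d); every (block_size + 1)-th key is a landmark.
    n, d = k.shape
    stride = block_size + 1
    assert n % stride == 0, "sequence must be a whole number of blocks"

    scores = (k @ q) / math.sqrt(d)                 # raw attention logits, (n,)
    scores = scores.view(-1, stride)                # (num_blocks, block_size + 1)
    token_scores, landmark_scores = scores[:, :-1], scores[:, -1]

    block_probs = torch.softmax(landmark_scores, dim=0)   # which block to read from
    token_probs = torch.softmax(token_scores, dim=-1)     # which token within that block
    probs = (block_probs[:, None] * token_probs).reshape(-1)

    v_tokens = v.view(-1, stride, d)[:, :-1, :].reshape(-1, d)  # drop landmark values
    return probs @ v_tokens

q = torch.randn(16)
k, v = torch.randn(102, 16), torch.randn(102, 16)   # 2 blocks of 50 tokens + 2 landmarks
out = grouped_softmax_attention(q, k, v, block_size=50)  # -> (16,)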

@@ -83,37 +83,47 @@ def load_model(
adapter="lora",
inference=False,
):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
is_llama_derived_model = "llama" in base_model or (
cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
if is_llama_derived_model and cfg.flash_attention:
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention:
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
logging.info("patching with xformers attention")
hijack_llama_attention()
elif is_llama_derived_model and cfg.sdp_attention:
elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
MEM_TOKEN,
hijack_llama_landmark_attn,
)
logging.info("patching with landmark attention")
hijack_llama_landmark_attn()
tokenizer.add_special_tokens({"mem_token": MEM_TOKEN})
if cfg.bf16:
torch_dtype = torch.bfloat16
@@ -145,7 +155,7 @@ def load_model(
             bnb_4bit_quant_type="nf4",
         )
     try:
-        if cfg.gptq and is_llama_derived_model:
+        if cfg.gptq and cfg.is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
@@ -183,7 +193,7 @@ def load_model(
                 else True,
             )
             load_in_8bit = False
-        elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
             config = LlamaConfig.from_pretrained(base_model_config)
             model = LlamaForCausalLM.from_pretrained(
                 base_model,

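The hunk above registers MEM_TOKEN as a tokenizer special token right after patching, and the trainer later looks its id up with convert_tokens_to_ids. A rough standalone illustration of that tokenizer side, using the standard additional_special_tokens route, a made-up token string (the real literal lives in llama_landmark_attn), and any Llama checkpoint name:

from transformers import AutoTokenizer, LlamaForCausalLM

MEM_TOKEN = "<landmark>"  # hypothetical placeholder for the patch's memory token

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # any Llama tokenizer/model pair
model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b")

# Register the extra special token and fetch the id that will be passed around later.
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)

# The embedding table must grow to cover the new id before training on it.
model.resize_token_embeddings(len(tokenizer))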

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
@@ -235,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8

+    if cfg.is_llama_derived_model and cfg.landmark_attention:
+        from functools import partial
+
+        from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+
+        mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+        model.set_mem_id(mem_id)
+
+        logging.info("Adding landmark attention tokens to dataset")
+
+        for dataset in [train_dataset, eval_dataset]:
+            dataset = dataset.map(
+                partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+                batched=False,
+                num_proc=32,
+            )
+
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
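Here add_mem_tokens is mapped over both datasets with mem_freq=50, i.e. the memory-token id is inserted into each tokenized example at a fixed interval. An illustrative stand-in follows (not the helper from llama_landmark_attn, which may also handle labels and attention masks); it shows the same partial/Dataset.map wiring with the mapped result assigned back, and uses a made-up token id:

from functools import partial
from datasets import Dataset

def insert_mem_tokens(example, mem_freq, mem_id):
    # Append the memory-token id after every `mem_freq` input ids.
    out = []
    for i, tok in enumerate(example["input_ids"], start=1):
        out.append(tok)
        if i % mem_freq == 0:
            out.append(mem_id)
    return {"input_ids": out}

ds = Dataset.from_dict({"input_ids": [list(range(120))]})
ds = ds.map(partial(insert_mem_tokens, mem_freq=50, mem_id=32001), batched=False)  # 32001: made-up id
print(len(ds[0]["input_ids"]))  # 122 -> two memory tokens were inserted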