Merge pull request #169 from NanoCode012/feat/landmark

Feat: Add landmark attention
This commit is contained in:
NanoCode012
2023-06-10 07:26:06 +09:00
committed by GitHub
4 changed files with 1635 additions and 8 deletions

View File

@@ -417,6 +417,8 @@ flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention # whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# whether to use landmark attention (only supported for llama-derived models)
landmark_attention:
# resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:

File diff suppressed because it is too large. (Load Diff)

View File

@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
) )
try: try:
from transformers import LlamaForCausalLM from transformers import ( # pylint: disable=unused-import # noqa: F401
LlamaForCausalLM,
)
except ImportError: except ImportError:
logging.warning( logging.warning(
"This version of transformers does not support Llama. Consider upgrading." "This version of transformers does not support Llama. Consider upgrading."
@@ -83,37 +85,47 @@ def load_model(
adapter="lora", adapter="lora",
inference=False, inference=False,
): ):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model from a base model and a model type.
""" """
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
is_llama_derived_model = "llama" in base_model or ( cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower() cfg.model_type and "llama" in cfg.model_type.lower()
) )
if is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False: if cfg.device not in ["mps", "cpu"] and inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention") logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn() replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention: elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention, hijack_llama_attention,
) )
logging.info("patching with xformers attention") logging.info("patching with xformers attention")
hijack_llama_attention() hijack_llama_attention()
elif is_llama_derived_model and cfg.sdp_attention: elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention, hijack_llama_sdp_attention,
) )
logging.info("patching with sdp attention") logging.info("patching with sdp attention")
hijack_llama_sdp_attention() hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
MEM_TOKEN,
LlamaForCausalLM,
)
logging.info("patching with landmark attention")
# TODO: Check if this would overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.bf16: if cfg.bf16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
@@ -145,7 +157,7 @@ def load_model(
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
try: try:
if cfg.gptq and is_llama_derived_model: if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@@ -183,7 +195,7 @@ def load_model(
else True, else True,
) )
load_in_8bit = False load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals(): elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
config = LlamaConfig.from_pretrained(base_model_config) config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,

View File

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib import importlib
import logging
import math import math
import os import os
import sys import sys
@@ -235,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else: else:
data_collator_kwargs["pad_to_multiple_of"] = 8 data_collator_kwargs["pad_to_multiple_of"] = 8
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
batched=False,
num_proc=32,
)
trainer_cls = ( trainer_cls = (
OneCycleLRSchedulerTrainer OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora") if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")