Merge pull request #169 from NanoCode012/feat/landmark
Feat: Add landmark attention
@@ -417,6 +417,8 @@ flash_attention: # require a100 for llama
 # whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
+# Landmark attention (only llama)
+landmark_attention:
 
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
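
The two new keys sit alongside the existing flash_attention / xformers_attention / sdp_attention switches; as the load_model changes further down show, only the first enabled option is applied, since they are consumed as branches of an if/elif chain. Below is a minimal, purely hypothetical pre-flight check for a parsed config, assuming a plain dict-like cfg; the validate_attention helper is illustrative only and not part of this PR or of axolotl:

# Hypothetical sketch: warn when more than one attention patch is requested,
# since load_model applies only the first matching branch
# (flash -> xformers -> sdp -> landmark).
import logging

def validate_attention(cfg: dict) -> None:
    flags = ("flash_attention", "xformers_attention", "sdp_attention", "landmark_attention")
    enabled = [name for name in flags if cfg.get(name)]
    if len(enabled) > 1:
        logging.warning("Multiple attention patches requested (%s); only %s takes effect", enabled, enabled[0])

validate_attention({"sdp_attention": False, "landmark_attention": True})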
src/axolotl/monkeypatch/llama_landmark_attn.py (new file, 1595 lines)
File diff suppressed because it is too large.
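
The 1595-line module itself is collapsed above. Judging only from how the rest of this diff uses it, it exposes at least a MEM_TOKEN special-token string, an add_mem_tokens helper for dataset mapping, and a patched LlamaForCausalLM with a set_mem_id method. Below is a minimal sketch of what the add_mem_tokens interface plausibly does (inserting a landmark/"memory" token id after every mem_freq ordinary tokens); the token string and the function body are assumptions for illustration, not the module's actual code:

# Illustrative sketch only: insert a landmark ("memory") token id after every
# mem_freq tokens so attention can later be routed through these landmarks.
# The real llama_landmark_attn module also patches the attention layers themselves.
MEM_TOKEN = "<landmark>"  # assumed placeholder; the real string is defined in the new module

def add_mem_tokens(example: dict, mem_freq: int, mem_id: int) -> dict:
    out = []
    for i, tok in enumerate(example["input_ids"], start=1):
        out.append(tok)
        if i % mem_freq == 0:
            out.append(mem_id)
    return {**example, "input_ids": out}

print(add_mem_tokens({"input_ids": list(range(7))}, mem_freq=3, mem_id=99))
# {'input_ids': [0, 1, 2, 99, 3, 4, 5, 99, 6]}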
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
 )
 
 try:
-    from transformers import LlamaForCausalLM
+    from transformers import (  # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -83,37 +85,47 @@ def load_model(
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    is_llama_derived_model = "llama" in base_model or (
+    cfg.is_llama_derived_model = "llama" in base_model or (
         cfg.model_type and "llama" in cfg.model_type.lower()
     )
 
-    if is_llama_derived_model and cfg.flash_attention:
+    if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
-    elif is_llama_derived_model and cfg.xformers_attention:
+    elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_attention,
         )
 
         logging.info("patching with xformers attention")
         hijack_llama_attention()
-    elif is_llama_derived_model and cfg.sdp_attention:
+    elif cfg.is_llama_derived_model and cfg.sdp_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
             hijack_llama_sdp_attention,
         )
 
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
+    elif cfg.is_llama_derived_model and cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import (  # pylint: disable=redefined-outer-name # noqa: F811
+            MEM_TOKEN,
+            LlamaForCausalLM,
+        )
+
+        logging.info("patching with landmark attention")
+
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
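
The landmark branch is the only one that needs tokenizer cooperation: a dedicated MEM_TOKEN has to exist in the vocabulary before the patched attention can find it. Below is a small sketch of that tokenizer-side effect using the stock Hugging Face API, with an assumed checkpoint id and an assumed token string; whether and where axolotl resizes the embedding matrix is outside this hunk:

# Sketch of registering the landmark token, assuming a standard HF tokenizer/model.
from transformers import AutoModelForCausalLM, AutoTokenizer

MEM_TOKEN = "<landmark>"  # placeholder; the real string comes from llama_landmark_attn
model_id = "huggyllama/llama-7b"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Register the landmark token as an additional special token...
num_added = tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
# ...and give the new id an embedding row (the usual HF pattern for added tokens).
if num_added:
    model.resize_token_embeddings(len(tokenizer))

mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)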
@@ -145,7 +157,7 @@ def load_model(
             bnb_4bit_quant_type="nf4",
         )
     try:
-        if cfg.gptq and is_llama_derived_model:
+        if cfg.gptq and cfg.is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
             from huggingface_hub import snapshot_download
 
@@ -183,7 +195,7 @@ def load_model(
                 else True,
             )
             load_in_8bit = False
-        elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
             config = LlamaConfig.from_pretrained(base_model_config)
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
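
A side refactor running through all of these hunks: the locally computed is_llama_derived_model flag now lives on the config object as cfg.is_llama_derived_model, so setup_trainer (next hunks) can read it without recomputing it. A tiny sketch of that pattern, assuming axolotl's DictDefault config type and its attribute-style access:

# Sketch of the cfg.is_llama_derived_model refactor; DictDefault allows
# attribute-style reads/writes on the parsed YAML config.
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"base_model": "huggyllama/llama-7b", "model_type": None})
# Derived once in load_model, then read anywhere the cfg travels.
cfg.is_llama_derived_model = "llama" in cfg.base_model or (
    cfg.model_type and "llama" in cfg.model_type.lower()
)
print(cfg.is_llama_derived_model)  # True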
@@ -1,6 +1,7 @@
 """Module containing the Trainer class and related functions"""
 
 import importlib
+import logging
 import math
 import os
 import sys
@@ -235,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
+    if cfg.is_llama_derived_model and cfg.landmark_attention:
+        from functools import partial
+
+        from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+
+        mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+        model.set_mem_id(mem_id)
+
+        logging.info("Adding landmark attention tokens to dataset")
+
+        for dataset in [train_dataset, eval_dataset]:
+            dataset = dataset.map(
+                partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+                batched=False,
+                num_proc=32,
+            )
+
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
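
The trainer change wires the landmark token into the tokenized datasets before the Trainer is built: look up the token id, tell the patched model about it via set_mem_id, then map add_mem_tokens over the train and eval splits. Below is a toy, self-contained version of the same datasets.map pattern, reusing the illustrative add_mem_tokens sketched earlier (still an assumption, not the module's real body); note that datasets.map returns a new Dataset, so the sketch reassigns the result:

# Toy sketch of the setup_trainer mapping step with a synthetic dataset.
from functools import partial

from datasets import Dataset

def add_mem_tokens(example, mem_freq, mem_id):
    # Illustrative stand-in for llama_landmark_attn.add_mem_tokens.
    out = []
    for i, tok in enumerate(example["input_ids"], start=1):
        out.append(tok)
        if i % mem_freq == 0:
            out.append(mem_id)
    return {**example, "input_ids": out}

mem_id = 32000  # whatever id the tokenizer assigned to the landmark token
data = Dataset.from_dict({"input_ids": [list(range(120))]})
data = data.map(partial(add_mem_tokens, mem_freq=50, mem_id=mem_id), batched=False)
print(len(data[0]["input_ids"]))  # 122: two landmark ids inserted into 120 tokens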