Merge pull request #169 from NanoCode012/feat/landmark
Feat: Add landmark attention
This commit is contained in:
@@ -417,6 +417,8 @@ flash_attention: # require a100 for llama
|
|||||||
# whether to use scaled-dot-product attention
|
# whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
|
# Landmark attention (only llama)
|
||||||
|
landmark_attention:
|
||||||
|
|
||||||
# resume from a specific checkpoint dir
|
# resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
|
|||||||
1595
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
1595
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import ( # pylint: disable=unused-import # noqa: F401
|
||||||
|
LlamaForCausalLM,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"This version of transformers does not support Llama. Consider upgrading."
|
"This version of transformers does not support Llama. Consider upgrading."
|
||||||
@@ -83,37 +85,47 @@ def load_model(
|
|||||||
adapter="lora",
|
adapter="lora",
|
||||||
inference=False,
|
inference=False,
|
||||||
):
|
):
|
||||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
"""
|
"""
|
||||||
Load a model from a base model and a model type.
|
Load a model from a base model and a model type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
is_llama_derived_model = "llama" in base_model or (
|
cfg.is_llama_derived_model = "llama" in base_model or (
|
||||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
cfg.model_type and "llama" in cfg.model_type.lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
if cfg.device not in ["mps", "cpu"] and inference is False:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
|
||||||
logging.info("patching with flash attention")
|
logging.info("patching with flash attention")
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
elif is_llama_derived_model and cfg.xformers_attention:
|
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with xformers attention")
|
logging.info("patching with xformers attention")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
elif is_llama_derived_model and cfg.sdp_attention:
|
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_sdp_attention,
|
hijack_llama_sdp_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with sdp attention")
|
logging.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
|
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
|
||||||
|
MEM_TOKEN,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("patching with landmark attention")
|
||||||
|
|
||||||
|
# TODO: Check if this would overwrite previous additional_special_tokens
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||||
|
|
||||||
if cfg.bf16:
|
if cfg.bf16:
|
||||||
torch_dtype = torch.bfloat16
|
torch_dtype = torch.bfloat16
|
||||||
@@ -145,7 +157,7 @@ def load_model(
|
|||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if cfg.gptq and is_llama_derived_model:
|
if cfg.gptq and cfg.is_llama_derived_model:
|
||||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -183,7 +195,7 @@ def load_model(
|
|||||||
else True,
|
else True,
|
||||||
)
|
)
|
||||||
load_in_8bit = False
|
load_in_8bit = False
|
||||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||||
config = LlamaConfig.from_pretrained(base_model_config)
|
config = LlamaConfig.from_pretrained(base_model_config)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -235,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
else:
|
else:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
|
||||||
|
|
||||||
|
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
|
||||||
|
model.set_mem_id(mem_id)
|
||||||
|
|
||||||
|
logging.info("Adding landmark attention tokens to dataset")
|
||||||
|
|
||||||
|
for dataset in [train_dataset, eval_dataset]:
|
||||||
|
dataset = dataset.map(
|
||||||
|
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
|
||||||
|
batched=False,
|
||||||
|
num_proc=32,
|
||||||
|
)
|
||||||
|
|
||||||
trainer_cls = (
|
trainer_cls = (
|
||||||
OneCycleLRSchedulerTrainer
|
OneCycleLRSchedulerTrainer
|
||||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||||
|
|||||||
Reference in New Issue
Block a user