remove landmark attn and xpos rope implementations (#1010)
This commit is contained in:
@@ -798,11 +798,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
|||||||
# Whether to use scaled-dot-product attention
|
# Whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
# Landmark attention (only llama)
|
|
||||||
landmark_attention:
|
|
||||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
|
||||||
# LLaMA only
|
|
||||||
xpos_rope:
|
|
||||||
|
|
||||||
# Resume from a specific checkpoint dir
|
# Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
|
|||||||
@@ -103,14 +103,6 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
|
||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
|
||||||
model.set_mem_cache_args(
|
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
model = model.to(cfg.device)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -176,14 +168,6 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
|
||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
|
||||||
model.set_mem_cache_args(
|
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
model = model.to(cfg.device)
|
||||||
|
|
||||||
def generate(instruction):
|
def generate(instruction):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import math
|
|||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial, wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -780,26 +780,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
|
||||||
add_mem_tokens,
|
|
||||||
get_mem_id,
|
|
||||||
set_model_mem_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
set_model_mem_id(self.model, self.tokenizer)
|
|
||||||
|
|
||||||
LOG.info("Adding landmark attention tokens to dataset")
|
|
||||||
|
|
||||||
for dataset in [self.train_dataset, self.eval_dataset]:
|
|
||||||
dataset = dataset.map(
|
|
||||||
partial(
|
|
||||||
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
|
||||||
),
|
|
||||||
batched=False,
|
|
||||||
num_proc=32,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
trainer_kwargs, trainer_cls
|
trainer_kwargs, trainer_cls
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,94 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
"""
|
|
||||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
class XposRotaryEmbedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
max_position_embeddings=2048,
|
|
||||||
base=10000,
|
|
||||||
device=None,
|
|
||||||
scale_base=2048,
|
|
||||||
use_xpos=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
|
||||||
self.scale_base = scale_base
|
|
||||||
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
|
||||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
|
||||||
|
|
||||||
if not use_xpos:
|
|
||||||
self.register_buffer("scale", None)
|
|
||||||
self.register_buffer("scale_cached", torch.ones(1))
|
|
||||||
return
|
|
||||||
|
|
||||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
|
||||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
|
||||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
|
||||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
|
||||||
|
|
||||||
self.register_buffer("scale", scale, persistent=False)
|
|
||||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
seq_len,
|
|
||||||
):
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
|
||||||
self.inv_freq
|
|
||||||
)
|
|
||||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
self.register_buffer("freqs_cached", freqs)
|
|
||||||
|
|
||||||
if self.scale is None:
|
|
||||||
self.register_buffer(
|
|
||||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
|
||||||
|
|
||||||
power = (t - (seq_len // 2)) / self.scale_base
|
|
||||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
|
||||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
|
||||||
self.register_buffer("scale_cached", scale)
|
|
||||||
|
|
||||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
|
||||||
freqs = freqs[position_ids, :]
|
|
||||||
if scale.shape[-1] != 1:
|
|
||||||
scale = scale[position_ids, :]
|
|
||||||
|
|
||||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
|
||||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
|
||||||
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_rope_with_xpos_rope():
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
|
||||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
|
||||||
@@ -247,17 +247,6 @@ def load_model(
|
|||||||
|
|
||||||
LOG.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
|
||||||
MEM_TOKEN,
|
|
||||||
patch_llama_with_landmark_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with landmark attention")
|
|
||||||
patch_llama_with_landmark_attn()
|
|
||||||
|
|
||||||
# Note: This might overwrite previous additional_special_tokens
|
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
|
||||||
|
|
||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
@@ -279,14 +268,6 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
|
||||||
replace_llama_rope_with_xpos_rope,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with xpos rope")
|
|
||||||
replace_llama_rope_with_xpos_rope()
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.is_llama_derived_model
|
cfg.is_llama_derived_model
|
||||||
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
||||||
|
|||||||
Reference in New Issue
Block a user