* Various bugfixes: use the latest TinyLlama release; check whether val_set_size is empty first; update the SDP and xformers Llama patches for the updated upstream transformers; fix the system prompt when there is no input; calculate total and total supervised tokens even when not sample packing. * Add a fix for when the eval size is estimated to be too small. * Should be len 1 for dataset length. * Add catch-all kwargs.
143 lines
4.7 KiB
Python
143 lines
4.7 KiB
Python
"""
|
|
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
|
"""
|
|
|
|
import warnings
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import transformers.models.llama.modeling_llama
|
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
|
|
|
|
|
def hijack_llama_sdp_attention():
    """Install ``sdp_attention_forward`` as ``LlamaAttention.forward``.

    The replacement is made on the class object in
    ``transformers.models.llama.modeling_llama``. Instances therefore pick up
    the patched method through normal attribute lookup, unless they define
    their own ``forward``.
    """
    llama_attention_cls = transformers.models.llama.modeling_llama.LlamaAttention
    setattr(llama_attention_cls, "forward", sdp_attention_forward)
|
|
|
|
|
|
def _sliced_linear(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    slice_size: int,
    num_slices: int,
) -> torch.Tensor:
    """Apply ``F.linear`` chunk-by-chunk over ``weight``'s rows and concatenate.

    ``weight`` is split along dim 0 into chunks of ``slice_size`` rows. The
    first ``num_slices`` chunks are each applied to ``hidden_states`` (no
    bias), and the partial outputs are concatenated along the last dim.
    """
    weight_slices = weight.split(slice_size, dim=0)
    return torch.cat(
        [F.linear(hidden_states, weight_slices[i]) for i in range(num_slices)],
        dim=-1,
    )


def sdp_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
    **kwargs,  # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Drop-in ``LlamaAttention.forward`` built on ``F.scaled_dot_product_attention``.

    The function:

    1. Projects ``hidden_states`` to query/key/value.
    2. Applies rotary position embeddings.
    3. Prepends any cached keys/values.
    4. Repeats key/value heads to match the query heads.
    5. Computes attention with SDPA.
    6. Applies the output projection.

    Args:
        self: The attention module this function is installed on. It must
            provide ``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``,
            ``rotary_emb``, ``num_heads``, ``num_key_value_heads``,
            ``num_key_value_groups``, ``head_dim`` and ``hidden_size``.
        hidden_states: Input of shape ``(bsz, q_len, hidden)``.
        attention_mask: Passed unchanged to SDPA as ``attn_mask`` with
            ``is_causal=False``. No causal masking is added here, so
            causality must already be encoded in the mask. When ``None``,
            every query attends to every key. (Presumably this is the 4D
            mask built by the surrounding Llama model; confirm against the
            caller.)
        position_ids: Forwarded to ``apply_rotary_pos_emb``.
        past_key_value: Optional ``(key, value)`` cache from earlier steps,
            each shaped ``(bsz, num_key_value_heads, past_len, head_dim)``.
            New keys/values are concatenated after it along dim 2.
        output_attentions: Attention weights are never computed. If ``True``,
            a warning is emitted and ``None`` is returned in their place.
        use_cache: If ``True``, return the updated ``(key, value)`` cache.
            The cache is stored before key/value heads are repeated.
        padding_mask: Accepted for signature compatibility; unused.
        **kwargs: Any extra keyword arguments are accepted and ignored.

    Returns:
        ``(attn_output, None, past_key_value)``, where ``attn_output`` has
        shape ``(bsz, q_len, hidden_size)`` and ``past_key_value`` is
        ``None`` unless ``use_cache`` is set.

    Raises:
        ValueError: If the SDPA output is not shaped
            ``(bsz, num_heads, q_len, head_dim)``.
    """
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    # Fall back to the unsliced projection path when the module carries no
    # `pretraining_tp` attribute.
    # NOTE(review): `config.pretraining_tp` is never consulted here.
    if not hasattr(self, "pretraining_tp"):
        self.pretraining_tp = 1

    if self.pretraining_tp > 1:
        # Compute each projection in `pretraining_tp` row-chunks of its weight.
        query_slicing = (self.num_heads * self.head_dim) // self.pretraining_tp
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.pretraining_tp
        query_states = _sliced_linear(
            hidden_states, self.q_proj.weight, query_slicing, self.pretraining_tp
        )
        key_states = _sliced_linear(
            hidden_states, self.k_proj.weight, key_value_slicing, self.pretraining_tp
        )
        value_states = _sliced_linear(
            hidden_states, self.v_proj.weight, key_value_slicing, self.pretraining_tp
        )
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    # View as [bsz, q_len, n_heads, head_dim], then move heads in front of the
    # sequence axis: [bsz, n_heads, q_len, head_dim].
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    # Rotary tables must cover the cached prefix plus the new tokens.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if past_key_value is not None:
        # Reuse cached keys/values: prepend them along the sequence axis.
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    # Cache before head repetition so it stays at num_key_value_heads.
    past_key_value = (key_states, value_states) if use_cache else None

    # Repeat k/v heads when n_kv_heads < n_heads (grouped-query attention).
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    # SDPA's memory-efficient CUDA backend has produced wrong results for
    # non-contiguous q/k/v combined with a custom attn_mask (pytorch issue
    # #112577). The transpose(1, 2) above leaves these tensors non-contiguous,
    # and repeat_kv is not guaranteed to fix that. `.contiguous()` is a no-op
    # when a tensor is already contiguous.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    #
    # sdp-attn start
    #

    # NOTE(review): `torch.backends.cuda.sdp_kernel` is deprecated in newer
    # torch releases in favour of `torch.nn.attention.sdpa_kernel`.
    with torch.backends.cuda.sdp_kernel():
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            is_causal=False,
        )

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    # [bsz, n_heads, q_len, head_dim] -> [bsz, q_len, hidden_size]
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    #
    # sdp-attn end
    #

    if self.pretraining_tp > 1:
        # Sliced output projection: split activations along features and
        # o_proj's weight along its input dim, then sum the partial products.
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.pretraining_tp, dim=1
        )
        attn_output = sum(
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.pretraining_tp)
        )
    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value
|