pre-commit formatting fixes
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-08-05 22:46:02 -04:00
parent 64852ae15a
commit 9793faf6dc

View File

@@ -7,9 +7,9 @@ import math
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
from torch import nn from torch import nn
import torch.nn.functional as F
try: try:
import xformers.ops import xformers.ops
@@ -39,44 +39,48 @@ def xformers_forward(
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if not hasattr(self, 'pretraining_tp'): if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1 self.pretraining_tp = 1
if self.pretraining_tp > 1: if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp key_value_slicing = (
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1) query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1) key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1) value_states = torch.cat(value_states, dim=-1)
else: else:
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states)
query_states = ( query_states = query_states.view(
query_states bsz, q_len, self.num_heads, self.head_dim
.view(bsz, q_len, self.num_heads, self.head_dim) ).transpose(1, 2)
.transpose(1, 2) key_states = key_states.view(
) bsz, q_len, self.num_key_value_heads, self.head_dim
key_states = ( ).transpose(1, 2)
key_states value_states = value_states.view(
.view(bsz, q_len, self.num_key_value_heads, self.head_dim) bsz, q_len, self.num_key_value_heads, self.head_dim
.transpose(1, 2) ).transpose(1, 2)
)
value_states = (
value_states
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
@@ -98,8 +102,12 @@ def xformers_forward(
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads # repeat k/v heads if n_kv_heads < n_heads
key_states = transformers.models.llama.modeling_llama.repeat_kv(key_states, self.num_key_value_groups) key_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states = transformers.models.llama.modeling_llama.repeat_kv(value_states, self.num_key_value_groups) key_states, self.num_key_value_groups
)
value_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states, self.num_key_value_groups
)
# We only apply xformers optimizations if we don't need to output the whole attention matrix # We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
@@ -157,17 +165,22 @@ def xformers_forward(
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
#end x-formers vs. not x-formers if-else block # end x-formers vs. not x-formers if-else block
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.pretraining_tp > 1: if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) o_proj_slices = self.o_proj.weight.split(
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else: else:
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value