Catch configs without pretraining_tp

This commit is contained in:
ssmi153
2023-08-05 11:45:12 +12:00
parent a300a4db1d
commit 1fed74b1d9

View File

@@ -39,6 +39,9 @@ def xformers_forward(
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, 'pretraining_tp'):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)