Catch configs without pretraining_tp
This commit is contained in:
@@ -39,6 +39,9 @@ def xformers_forward(
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, 'pretraining_tp'):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user