Catch configs without pretraining_tp
This commit is contained in:
@@ -39,6 +39,9 @@ def xformers_forward(
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if not hasattr(self, 'pretraining_tp'):
|
||||||
|
self.pretraining_tp = 1
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
if self.pretraining_tp > 1:
|
||||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
|
||||||
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
|
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user