stack
This commit is contained in:
@@ -137,7 +137,7 @@ def get_packed_mask_from_pos_ids(position_ids):
|
|||||||
|
|
||||||
results.append(doc_mask)
|
results.append(doc_mask)
|
||||||
|
|
||||||
return results
|
return torch.stack(results)
|
||||||
|
|
||||||
|
|
||||||
def get_seqlens_from_pos_ids(position_ids):
|
def get_seqlens_from_pos_ids(position_ids):
|
||||||
|
|||||||
Reference in New Issue
Block a user