Fix the bug of position ids padding (#2739) [skip ci]
* Update batching.py: fix the bug of position ids padding

  If position_ids is padded with a long sequence of zeros, it will cause flash attention to crash.

* Use an alternate calculation for padding position_ids with a range

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
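For context: with sample packing, packed-sequence boundaries are recovered from the points where position_ids resets to 0, and flash attention's varlen path consumes the resulting cumulative sequence lengths. The sketch below (seq_lens_from_position_ids is a hypothetical helper for illustration, not axolotl's actual boundary code) shows why padding position_ids with zeros corrupts those boundaries while padding with a range does not:

from typing import List

def seq_lens_from_position_ids(position_ids: List[int]) -> List[int]:
    """Split one packed row into per-sequence lengths.

    Assumes a new sequence starts wherever position_ids drops back to 0,
    which is how packed batches are typically un-packed for varlen attention.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    starts.append(len(position_ids))
    return [b - a for a, b in zip(starts, starts[1:])]

# Two packed sequences of lengths 4 and 3, plus 5 positions of padding.
packed = [0, 1, 2, 3, 0, 1, 2]

# Old behavior: zero padding. Every padded 0 looks like the start of yet
# another length-1 sequence, so the kernel receives bogus cu_seqlens.
print(seq_lens_from_position_ids(packed + [0] * 5))        # [4, 3, 1, 1, 1, 1, 1]

# New behavior: pad with a range. The padding reads as one well-formed
# dummy sequence with monotonically increasing positions.
print(seq_lens_from_position_ids(packed + list(range(5)))) # [4, 3, 5]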
@@ -81,9 +81,11 @@ class DataCollatorForSeq2Seq:
                 padding_side = self.tokenizer.padding_side
                 for feature in features:
-                    remainder = [pad_token_id] * (
-                        max_feature_length - len(feature[feature_name])
-                    )
+                    remainder_len = max_feature_length - len(feature[feature_name])
+                    if feature_name == "position_ids":
+                        remainder = list(range(remainder_len))
+                    else:
+                        remainder = [pad_token_id] * remainder_len
                     if isinstance(feature[feature_name], list):
                         feature[feature_name] = (
                             feature[feature_name] + remainder
                             if padding_side == "right"
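In isolation, the patched padding logic behaves as follows (pad_feature is a standalone sketch mirroring the diff under right padding, not the collator's real API):

def pad_feature(values, feature_name, max_feature_length, pad_token_id=-100):
    # Mirrors the diff: position_ids get a fresh range, everything else
    # gets the usual pad token (right padding assumed for brevity).
    remainder_len = max_feature_length - len(values)
    if feature_name == "position_ids":
        remainder = list(range(remainder_len))
    else:
        remainder = [pad_token_id] * remainder_len
    return values + remainder

print(pad_feature([0, 1, 2, 3, 0, 1, 2], "position_ids", 12))
# -> [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]
print(pad_feature([5, 6, 7, 8], "labels", 8))
# -> [5, 6, 7, 8, -100, -100, -100, -100]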