From 92afa4fa272c9139c58e8c4527cedb61750cc3e1 Mon Sep 17 00:00:00 2001
From: Qingyang Wu
Date: Mon, 9 Jun 2025 21:26:36 -0700
Subject: [PATCH] Fix the bug of position ids padding (#2739) [skip ci]

* Update batching.py: fix the bug of position ids padding

if position ids is padded with a long sequence of zeros, it will cause
flash attention to crash

* use alternate calculation for padding position_ids with a range

---------

Co-authored-by: Wing Lian
---
 src/axolotl/utils/collators/batching.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 45facf832..d8414d117 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -81,9 +81,11 @@ class DataCollatorForSeq2Seq:
         padding_side = self.tokenizer.padding_side

         for feature in features:
-            remainder = [pad_token_id] * (
-                max_feature_length - len(feature[feature_name])
-            )
+            remainder_len = max_feature_length - len(feature[feature_name])
+            if feature_name == "position_ids":
+                remainder = list(range(remainder_len))
+            else:
+                remainder = [pad_token_id] * remainder_len
             if isinstance(feature[feature_name], list):
                 feature[feature_name] = (
                     feature[feature_name] + remainder