diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 579fbfbb3..39f3d0c04 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -169,19 +169,26 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if feature == "length": continue if feature == "attention_mask": - seq_len_list = [ - get_seqlens_from_pos_ids(item["position_ids"]) - for item in features_ + arrays = [ + (i + 1) * np.array(item[feature]) + for i, item in enumerate(features_) if feature in item ] - - out_features[i][feature] = np.concatenate(seq_len_list) + out_features[i][feature] = np.concatenate(arrays) else: arrays = [ np.array(item[feature]) for item in features_ if feature in item ] out_features[i][feature] = np.concatenate(arrays) - return super().__call__(out_features, return_tensors=return_tensors) + out = super().__call__(out_features, return_tensors=return_tensors) + + collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"]) + + doc_mask = + + out["attention_mask"] + + return out @dataclass