From ba88bc784012d7ee98cd8b8de78285bc7651aa7d Mon Sep 17 00:00:00 2001
From: bursteratom
Date: Wed, 29 Jan 2025 00:25:25 -0500
Subject: [PATCH] flex block mask creation

Label each packed sample's attention_mask with its sample index, then
after collation derive per-token document ids from position-id resets
and build a block-diagonal document mask so flex attention only attends
within each packed document.
---
 src/axolotl/utils/collators/batching.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 579fbfbb3..39f3d0c04 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -169,19 +169,32 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 if feature == "length":
                     continue
                 if feature == "attention_mask":
-                    seq_len_list = [
-                        get_seqlens_from_pos_ids(item["position_ids"])
-                        for item in features_
+                    arrays = [
+                        (i + 1) * np.array(item[feature])
+                        for i, item in enumerate(features_)
                         if feature in item
                     ]
-
-                    out_features[i][feature] = np.concatenate(seq_len_list)
+                    out_features[i][feature] = np.concatenate(arrays)
                 else:
                     arrays = [
                         np.array(item[feature]) for item in features_ if feature in item
                     ]
                     out_features[i][feature] = np.concatenate(arrays)
-        return super().__call__(out_features, return_tensors=return_tensors)
+        out = super().__call__(out_features, return_tensors=return_tensors)
+
+        # Position id 0 marks the start of each packed document, so a
+        # cumulative count of zeros yields a per-token document id.
+        document_ids = (out["position_ids"] == 0).cumsum(dim=-1)
+
+        # Block-diagonal document mask: a token may only attend to tokens
+        # within its own document.
+        # NOTE(review): assumes the parent collator pads position_ids with 0,
+        # which makes each pad token its own document — confirm downstream.
+        doc_mask = document_ids.unsqueeze(-1) == document_ids.unsqueeze(-2)
+
+        out["attention_mask"] = doc_mask
+
+        return out
 
 
 @dataclass