packed doc mask starts at 1, 0 means masked out

This commit is contained in:
Sunny Liu
2025-02-07 14:44:52 -05:00
parent d0e739da24
commit c0a1d205c7
2 changed files with 2 additions and 4 deletions

View File

@@ -130,10 +130,10 @@ def get_packed_mask_from_pos_ids(position_ids):
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
i * doc_mask[start_id : start_id + seq_len]
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = -100 * doc_mask[len(adjusted_row) :]
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
results.append(doc_mask)

View File

@@ -16,7 +16,6 @@ from axolotl.monkeypatch.flex_attn import (
)
from axolotl.monkeypatch.utils import (
get_packed_mask_from_pos_ids,
get_seqlens_from_pos_ids,
)
@@ -194,7 +193,6 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
# raise ValueError(f"{out['attention_mask'].shape}")
return out