packed doc mask starts at 1, 0 means masked out
This commit is contained in:
@@ -130,10 +130,10 @@ def get_packed_mask_from_pos_ids(position_ids):
|
||||
for i, seq_len in enumerate(seq_lengths):
|
||||
start_id = start_indices[i]
|
||||
doc_mask[start_id : start_id + seq_len] = (
|
||||
i * doc_mask[start_id : start_id + seq_len]
|
||||
(i+1) * doc_mask[start_id : start_id + seq_len]
|
||||
)
|
||||
if padding_length:
|
||||
doc_mask[len(adjusted_row) :] = -100 * doc_mask[len(adjusted_row) :]
|
||||
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
|
||||
|
||||
results.append(doc_mask)
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from axolotl.monkeypatch.flex_attn import (
|
||||
)
|
||||
from axolotl.monkeypatch.utils import (
|
||||
get_packed_mask_from_pos_ids,
|
||||
get_seqlens_from_pos_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -194,7 +193,6 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
||||
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
|
||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||
# raise ValueError(f"{out['attention_mask'].shape}")
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user