packed doc mask starts at 1, 0 means masked out
This commit is contained in:
@@ -130,10 +130,10 @@ def get_packed_mask_from_pos_ids(position_ids):
|
|||||||
for i, seq_len in enumerate(seq_lengths):
|
for i, seq_len in enumerate(seq_lengths):
|
||||||
start_id = start_indices[i]
|
start_id = start_indices[i]
|
||||||
doc_mask[start_id : start_id + seq_len] = (
|
doc_mask[start_id : start_id + seq_len] = (
|
||||||
i * doc_mask[start_id : start_id + seq_len]
|
(i+1) * doc_mask[start_id : start_id + seq_len]
|
||||||
)
|
)
|
||||||
if padding_length:
|
if padding_length:
|
||||||
doc_mask[len(adjusted_row) :] = -100 * doc_mask[len(adjusted_row) :]
|
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
|
||||||
|
|
||||||
results.append(doc_mask)
|
results.append(doc_mask)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from axolotl.monkeypatch.flex_attn import (
|
|||||||
)
|
)
|
||||||
from axolotl.monkeypatch.utils import (
|
from axolotl.monkeypatch.utils import (
|
||||||
get_packed_mask_from_pos_ids,
|
get_packed_mask_from_pos_ids,
|
||||||
get_seqlens_from_pos_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -194,7 +193,6 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
||||||
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
|
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
|
||||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
# raise ValueError(f"{out['attention_mask'].shape}")
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user