diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 3b35f29dd..b6bb159a1 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -130,10 +130,10 @@ def get_packed_mask_from_pos_ids(position_ids):
         for i, seq_len in enumerate(seq_lengths):
             start_id = start_indices[i]
             doc_mask[start_id : start_id + seq_len] = (
-                i * doc_mask[start_id : start_id + seq_len]
+                (i+1) * doc_mask[start_id : start_id + seq_len]
             )
         if padding_length:
-            doc_mask[len(adjusted_row) :] = -100 * doc_mask[len(adjusted_row) :]
+            doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]

         results.append(doc_mask)

diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 9e247877b..72b4a5475 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -16,7 +16,6 @@ from axolotl.monkeypatch.flex_attn import (
 )
 from axolotl.monkeypatch.utils import (
     get_packed_mask_from_pos_ids,
-    get_seqlens_from_pos_ids,
 )


@@ -194,7 +193,6 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         # out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
         out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
         # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
-        # raise ValueError(f"{out['attention_mask'].shape}")

         return out