figure out slight diff from flash result
This commit is contained in:
@@ -94,7 +94,7 @@ def _get_document_ids_from_seq_lens(
|
|||||||
|
|
||||||
|
|
||||||
def packed_block_causal_mask(
|
def packed_block_causal_mask(
|
||||||
seq_lens: list[torch.Tensor], max_seq_len: int
|
seq_lens: list[torch.Tensor], totalseqlens: list[int]
|
||||||
) -> _MaskType:
|
) -> _MaskType:
|
||||||
"""
|
"""
|
||||||
Create a block causal document mask for a batch of packed sequences. If
|
Create a block causal document mask for a batch of packed sequences. If
|
||||||
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||||
batch_size, _ = document_ids.shape
|
batch_size , max_seq_len = document_ids
|
||||||
document_ids = document_ids.to("cuda")
|
document_ids = document_ids.to("cuda")
|
||||||
|
|
||||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||||
@@ -131,7 +131,7 @@ def packed_block_causal_mask(
|
|||||||
"""
|
"""
|
||||||
causal_mask = q_idx >= kv_idx
|
causal_mask = q_idx >= kv_idx
|
||||||
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
|
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
|
||||||
return causal_mask & document_mask
|
return causal_mask & document_mask & (q_idx < totalseqlens[b])
|
||||||
|
|
||||||
return create_block_causal_mask_flex(
|
return create_block_causal_mask_flex(
|
||||||
mask_mod,
|
mask_mod,
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ def get_seqlens_from_pos_ids(position_ids):
|
|||||||
|
|
||||||
device = position_ids.device
|
device = position_ids.device
|
||||||
results = []
|
results = []
|
||||||
|
totalseqlens = []
|
||||||
|
|
||||||
for row in position_ids:
|
for row in position_ids:
|
||||||
# Count the number of consecutive zeros from the right side
|
# Count the number of consecutive zeros from the right side
|
||||||
@@ -128,7 +129,7 @@ def get_seqlens_from_pos_ids(position_ids):
|
|||||||
# Calculate the sequence lengths
|
# Calculate the sequence lengths
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||||
# Append the padding length to the sequence lengths
|
# Append the padding length to the sequence lengths
|
||||||
"""if padding_length:
|
if padding_length:
|
||||||
seq_lengths = torch.cat(
|
seq_lengths = torch.cat(
|
||||||
[
|
[
|
||||||
seq_lengths,
|
seq_lengths,
|
||||||
@@ -138,11 +139,12 @@ def get_seqlens_from_pos_ids(position_ids):
|
|||||||
device=device,
|
device=device,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)"""
|
)
|
||||||
|
|
||||||
results.append(seq_lengths)
|
results.append(seq_lengths)
|
||||||
|
totalseqlens.append(len(adjusted_row))
|
||||||
|
|
||||||
return results , max_seq_len
|
return results , totalseqlens
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
|
|||||||
@@ -179,8 +179,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len)
|
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
||||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
# raise ValueError(f"{out['attention_mask'].shape}")
|
# raise ValueError(f"{out['attention_mask'].shape}")
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user