figure out slight diff from flash result

This commit is contained in:
bursteratom
2025-02-02 01:45:54 -05:00
parent 0ebab63309
commit d3ea379a23
3 changed files with 10 additions and 8 deletions

View File

@@ -94,7 +94,7 @@ def _get_document_ids_from_seq_lens(
def packed_block_causal_mask( def packed_block_causal_mask(
seq_lens: list[torch.Tensor], max_seq_len: int seq_lens: list[torch.Tensor], totalseqlens: list[int]
) -> _MaskType: ) -> _MaskType:
""" """
Create a block causal document mask for a batch of packed sequences. If Create a block causal document mask for a batch of packed sequences. If
@@ -113,7 +113,7 @@ def packed_block_causal_mask(
""" """
document_ids = _get_document_ids_from_seq_lens(seq_lens) document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size, _ = document_ids.shape batch_size , max_seq_len = document_ids
document_ids = document_ids.to("cuda") document_ids = document_ids.to("cuda")
# Instead of passing a tensor mask, flex attention requires a mask_mod function # Instead of passing a tensor mask, flex attention requires a mask_mod function
@@ -131,7 +131,7 @@ def packed_block_causal_mask(
""" """
causal_mask = q_idx >= kv_idx causal_mask = q_idx >= kv_idx
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx] document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
return causal_mask & document_mask return causal_mask & document_mask & (q_idx < totalseqlens[b])
return create_block_causal_mask_flex( return create_block_causal_mask_flex(
mask_mod, mask_mod,

View File

@@ -103,6 +103,7 @@ def get_seqlens_from_pos_ids(position_ids):
device = position_ids.device device = position_ids.device
results = [] results = []
totalseqlens = []
for row in position_ids: for row in position_ids:
# Count the number of consecutive zeros from the right side # Count the number of consecutive zeros from the right side
@@ -128,7 +129,7 @@ def get_seqlens_from_pos_ids(position_ids):
# Calculate the sequence lengths # Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1] seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths # Append the padding length to the sequence lengths
"""if padding_length: if padding_length:
seq_lengths = torch.cat( seq_lengths = torch.cat(
[ [
seq_lengths, seq_lengths,
@@ -138,11 +139,12 @@ def get_seqlens_from_pos_ids(position_ids):
device=device, device=device,
), ),
] ]
)""" )
results.append(seq_lengths) results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results , max_seq_len return results , totalseqlens
def get_cu_seqlens_from_pos_ids(position_ids): def get_cu_seqlens_from_pos_ids(position_ids):

View File

@@ -179,8 +179,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out_features[i][feature] = np.concatenate(arrays) out_features[i][feature] = np.concatenate(arrays)
out = super().__call__(out_features, return_tensors=return_tensors) out = super().__call__(out_features, return_tensors=return_tensors)
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"]) collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, max_seq_len) out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len) # out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
# raise ValueError(f"{out['attention_mask'].shape}") # raise ValueError(f"{out['attention_mask'].shape}")
return out return out