make doc mask instead of the whole block mask in collator
This commit is contained in:
@@ -95,6 +95,51 @@ def get_cu_seqlens(attn_mask):
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def get_packed_mask_from_pos_ids(position_ids):
|
||||
if len(position_ids.shape) == 1:
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
device = position_ids.device
|
||||
results = []
|
||||
|
||||
for i, row in enumerate(position_ids):
|
||||
# Count the number of consecutive zeros from the right side
|
||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
||||
|
||||
# Adjust the row to exclude padding
|
||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
||||
|
||||
# Find where the position resets to 0 (indicating a new sequence)
|
||||
seq_starts = torch.cat(
|
||||
[
|
||||
torch.tensor([True], dtype=torch.bool, device=device),
|
||||
adjusted_row[1:] == 0,
|
||||
]
|
||||
)
|
||||
# Get the indices where the sequence starts
|
||||
start_indices = torch.cat(
|
||||
[
|
||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||
]
|
||||
)
|
||||
# Calculate the sequence lengths
|
||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||
# Append the padding length to the sequence lengths
|
||||
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
|
||||
for i, seq_len in enumerate(seq_lengths):
|
||||
start_id = start_indices[i]
|
||||
doc_mask[start_id : start_id + seq_len] = (
|
||||
i * doc_mask[start_id : start_id + seq_len]
|
||||
)
|
||||
if padding_length:
|
||||
doc_mask[len(adjusted_row) :] = -100 * doc_mask[seq_lengths[-1] :]
|
||||
|
||||
results.append(doc_mask)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
||||
if len(position_ids.shape) == 1:
|
||||
@@ -144,7 +189,7 @@ def get_seqlens_from_pos_ids(position_ids):
|
||||
results.append(seq_lengths)
|
||||
totalseqlens.append(len(adjusted_row))
|
||||
|
||||
return results , torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
||||
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
|
||||
@@ -14,7 +14,10 @@ from axolotl.monkeypatch.flex_attn import (
|
||||
create_block_causal_mask,
|
||||
packed_block_causal_mask,
|
||||
)
|
||||
from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.utils import (
|
||||
get_packed_mask_from_pos_ids,
|
||||
get_seqlens_from_pos_ids,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -170,7 +173,15 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature in {"length", "attention_mask"}:
|
||||
if feature == "length":
|
||||
continue
|
||||
elif feature == "attention_mask":
|
||||
"""arrays = [
|
||||
i * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)"""
|
||||
continue
|
||||
else:
|
||||
arrays = [
|
||||
@@ -179,8 +190,9 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||
out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
||||
# collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
||||
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
|
||||
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||
# raise ValueError(f"{out['attention_mask'].shape}")
|
||||
return out
|
||||
@@ -243,4 +255,4 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
arrays = [np.array(item) for item in features[feature]]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
Reference in New Issue
Block a user