more test
This commit is contained in:
@@ -50,8 +50,10 @@ def create_block_causal_mask(
|
|||||||
|
|
||||||
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
||||||
block_attn_masks.append(
|
block_attn_masks.append(
|
||||||
torch.zeros(
|
torch.tril(
|
||||||
residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device
|
torch.zeros(
|
||||||
|
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
||||||
out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
raise ValueError(f"{out['attention_mask'].shape}")
|
# raise ValueError(f"{out['attention_mask'].shape}")
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user