more test
This commit is contained in:
@@ -51,7 +51,7 @@ def create_block_causal_mask(
|
||||
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
||||
block_attn_masks.append(
|
||||
torch.tril(
|
||||
torch.zeros(
|
||||
torch.ones(
|
||||
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user