more test
This commit is contained in:
@@ -48,14 +48,14 @@ def create_block_causal_mask(
|
|||||||
for seq_len in seq_lens[sample_idx]
|
for seq_len in seq_lens[sample_idx]
|
||||||
]
|
]
|
||||||
|
|
||||||
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
"""residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
||||||
block_attn_masks.append(
|
block_attn_masks.append(
|
||||||
torch.tril(
|
torch.tril(
|
||||||
torch.ones(
|
torch.ones(
|
||||||
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
|
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)"""
|
||||||
|
|
||||||
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user