more test

This commit is contained in:
bursteratom
2025-02-02 00:48:15 -05:00
parent 9a43a0925d
commit 2319e5276d

View File

@@ -48,14 +48,14 @@ def create_block_causal_mask(
 for seq_len in seq_lens[sample_idx]
 ]
-residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
+"""residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
 block_attn_masks.append(
 torch.tril(
 torch.ones(
 residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
 )
 )
-)
+)"""
 batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))