attempt - mask padding
This commit is contained in:
@@ -131,7 +131,8 @@ def packed_block_causal_mask(
|
|||||||
"""
|
"""
|
||||||
causal_mask = q_idx >= kv_idx
|
causal_mask = q_idx >= kv_idx
|
||||||
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
|
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
|
||||||
return causal_mask & document_mask & (q_idx < totalseqlens[b])
|
finite_mask = q_idx < totalseqlens[b]
|
||||||
|
return causal_mask & document_mask & finite_mask
|
||||||
|
|
||||||
return create_block_causal_mask_flex(
|
return create_block_causal_mask_flex(
|
||||||
mask_mod,
|
mask_mod,
|
||||||
|
|||||||
Reference in New Issue
Block a user