This commit is contained in:
Dan Saunders
2025-04-14 14:41:52 +00:00
parent 4ae8df16a9
commit 5306c6acbb

View File

@@ -57,12 +57,7 @@ class RingAttnFunc(str, Enum):
def register_ring_attn( def register_ring_attn(
sequence_parallel_degree: int, sequence_parallel_degree: int,
heads_k_stride: int | None, heads_k_stride: int | None,
<<<<<<< HEAD
ring_attn_func: RingAttnFunc | None, ring_attn_func: RingAttnFunc | None,
=======
sample_packing: bool,
ring_attn_func: str | None,
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
): ):
""" """
Create ring attention group and substitute flash attn with ring flash attn. Create ring attention group and substitute flash attn with ring flash attn.
@@ -125,10 +120,6 @@ def register_ring_attn(
substitute_hf_flash_attn( substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
) )
<<<<<<< HEAD
=======
# TODO(djsaunde): handle other ring attn funcs in this branch
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
elif ring_attn_func in [ elif ring_attn_func in [
RingAttnFunc.BATCH_RING, RingAttnFunc.BATCH_RING,
RingAttnFunc.BATCH_ZIGZAG, RingAttnFunc.BATCH_ZIGZAG,