This commit is contained in:
Dan Saunders
2025-04-14 14:41:52 +00:00
parent 4ae8df16a9
commit 5306c6acbb

View File

@@ -57,12 +57,7 @@ class RingAttnFunc(str, Enum):
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
<<<<<<< HEAD
ring_attn_func: RingAttnFunc | None,
=======
sample_packing: bool,
ring_attn_func: str | None,
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
):
"""
Create ring attention group and substitute flash attn with ring flash attn.
@@ -125,10 +120,6 @@ def register_ring_attn(
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
<<<<<<< HEAD
=======
# TODO(djsaunde): handle other ring attn funcs in this branch
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
elif ring_attn_func in [
RingAttnFunc.BATCH_RING,
RingAttnFunc.BATCH_ZIGZAG,