fix
This commit is contained in:
@@ -57,12 +57,7 @@ class RingAttnFunc(str, Enum):
|
|||||||
def register_ring_attn(
|
def register_ring_attn(
|
||||||
sequence_parallel_degree: int,
|
sequence_parallel_degree: int,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
<<<<<<< HEAD
|
|
||||||
ring_attn_func: RingAttnFunc | None,
|
ring_attn_func: RingAttnFunc | None,
|
||||||
=======
|
|
||||||
sample_packing: bool,
|
|
||||||
ring_attn_func: str | None,
|
|
||||||
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
@@ -125,10 +120,6 @@ def register_ring_attn(
|
|||||||
substitute_hf_flash_attn(
|
substitute_hf_flash_attn(
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||||
)
|
)
|
||||||
<<<<<<< HEAD
|
|
||||||
=======
|
|
||||||
# TODO(djsaunde): handle other ring attn funcs in this branch
|
|
||||||
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
|
|
||||||
elif ring_attn_func in [
|
elif ring_attn_func in [
|
||||||
RingAttnFunc.BATCH_RING,
|
RingAttnFunc.BATCH_RING,
|
||||||
RingAttnFunc.BATCH_ZIGZAG,
|
RingAttnFunc.BATCH_ZIGZAG,
|
||||||
|
|||||||
Reference in New Issue
Block a user