fix
This commit is contained in:
@@ -57,12 +57,7 @@ class RingAttnFunc(str, Enum):
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
<<<<<<< HEAD
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
=======
|
||||
sample_packing: bool,
|
||||
ring_attn_func: str | None,
|
||||
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
|
||||
):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
@@ -125,10 +120,6 @@ def register_ring_attn(
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
# TODO(djsaunde): handle other ring attn funcs in this branch
|
||||
>>>>>>> 8799e9a6 (adding all batch ring-flash-attn methods via single adapter)
|
||||
elif ring_attn_func in [
|
||||
RingAttnFunc.BATCH_RING,
|
||||
RingAttnFunc.BATCH_ZIGZAG,
|
||||
|
||||
Reference in New Issue
Block a user