From 3f8a43cab69955e8a068d1858c3b9c48a55ffdec Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 3 Mar 2025 19:59:10 +0000 Subject: [PATCH] adding easy_context as integration for now --- requirements.txt | 4 + .../integrations/easy_context/__init__.py | 96 ++ .../easy_context/dist_flash_attn/README.md | 11 + .../dist_flash_attn/async_communication.py | 776 ++++++++++ .../dist_flash_attn/lightseq_async_attn.py | 1209 ++++++++++++++++ .../lightseq_async_attn_varlen.py | 1273 +++++++++++++++++ .../dist_flash_attn/monkey_patch.py | 755 ++++++++++ .../dist_flash_attn/prepare_input.py | 34 + .../easy_context/ulysses_attn/monkey_patch.py | 114 ++ .../ulysses_attn/prepare_inputs.py | 44 + .../monkey_patch.py | 95 ++ .../easy_context/usp/monkey_patch.py | 114 ++ .../easy_context/usp/prepare_inputs.py | 58 + .../zigzag_ring_attn/monkey_patch.py | 114 ++ .../zigzag_ring_attn/prepare_inputs.py | 40 + src/axolotl/monkeypatch/sequence_parallel.py | 123 ++ src/axolotl/utils/models.py | 14 + 17 files changed, 4874 insertions(+) create mode 100644 src/axolotl/integrations/easy_context/__init__.py create mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/README.md create mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py create mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py create mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py create mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py create mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py create mode 100644 src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py create mode 100644 src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py create mode 100644 src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py create mode 100644 src/axolotl/integrations/easy_context/usp/monkey_patch.py create mode 100644 src/axolotl/integrations/easy_context/usp/prepare_inputs.py create mode 100644 src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py create mode 100644 src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py create mode 100644 src/axolotl/monkeypatch/sequence_parallel.py diff --git a/requirements.txt b/requirements.txt index 495f43af6..02ecd56f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,6 +36,7 @@ einops colorama numba numpy>=1.24.4,<=2.0.1 + # qlora things evaluate==0.4.1 scipy @@ -64,3 +65,6 @@ schedulefree==1.3.0 axolotl-contribs-lgpl==0.0.6 axolotl-contribs-mit==0.0.3 + +# for sequence parallelism +ring-flash-attn>=0.1.4 diff --git a/src/axolotl/integrations/easy_context/__init__.py b/src/axolotl/integrations/easy_context/__init__.py new file mode 100644 index 000000000..93d38a2e9 --- /dev/null +++ b/src/axolotl/integrations/easy_context/__init__.py @@ -0,0 +1,96 @@ +import logging + +from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama +from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs +from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama +from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs +from .usp.monkey_patch import apply_usp_attn_monkey_patch_llama +from .usp.prepare_inputs import prepare_usp_attn_inputs +from .zigzag_ring_attn.monkey_patch import ( + apply_zigzag_ring_attn_monkey_patch_llama, + 
apply_zigzag_ring_attn_monkey_patch_mistral, +) +from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs + +logger = logging.getLogger(__name__) + + +def prepare_seq_parallel_inputs( + seq_algo, + input_ids, + position_ids, + target_ids, + rank, + world_size, + device, + *args, + **kwargs, +): + if seq_algo == "zigzag_ring_attn": + return prepare_zigzag_ring_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "dist_flash_attn": + return prepare_dist_flash_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "ulysses_attn": + return prepare_ulysses_attn_inputs( + input_ids, position_ids, target_ids, rank, world_size, device + ) + elif seq_algo == "usp_attn": + ring_degree = kwargs.get("ring_degree", 1) + ulysses_degree = world_size // ring_degree + logger.info( + f"Applying USP: Ring degree: {ring_degree}, Ulysses degree: {ulysses_degree}" + ) + return prepare_usp_attn_inputs( + input_ids, + position_ids, + target_ids, + rank, + world_size, + device, + ulysses_degree, + ring_degree, + ) + elif seq_algo == "data_parallel": + return { + "local_input_ids": input_ids.to(device), + "local_position_ids": position_ids.to(device), + "local_target_ids": target_ids.to(device), + } + else: + raise ValueError(f"Invalid seq_algo: {seq_algo}") + + +def apply_seq_parallel_monkey_patch(seq_algo, model): + assert seq_algo in [ + "zigzag_ring_attn", + "dist_flash_attn", + "ulysses_attn", + "data_parallel", + "usp_attn", + ], f"Invalid seq_algo: {seq_algo}" + assert model in ["llama", "mistral"], f"Invalid model: {model}" + if seq_algo == "data_parallel": + return + elif seq_algo == "zigzag_ring_attn" and model == "llama": + apply_zigzag_ring_attn_monkey_patch_llama() + elif seq_algo == "zigzag_ring_attn" and model == "mistral": + apply_zigzag_ring_attn_monkey_patch_mistral() + elif seq_algo == "dist_flash_attn" and model == "llama": + apply_dist_flash_attn_monkey_patch_llama() + elif seq_algo == "ulysses_attn" and model == "llama": + apply_ulysses_attn_monkey_patch_llama() + elif seq_algo == "usp_attn" and model == "llama": + apply_usp_attn_monkey_patch_llama() + else: + raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") + + +def prepare_dataloader(seq_algo, dataloader, accelerator): + if seq_algo == "data_parallel": + return accelerator.prepare(dataloader) + else: + return dataloader diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/README.md b/src/axolotl/integrations/easy_context/dist_flash_attn/README.md new file mode 100644 index 000000000..9ddeb332e --- /dev/null +++ b/src/axolotl/integrations/easy_context/dist_flash_attn/README.md @@ -0,0 +1,11 @@ +# LightSeq +Taken from https://github.com/RulinShao/LightSeq. All credits to the authors. 
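+
+The surrounding `easy_context` package exposes these kernels through
+`apply_seq_parallel_monkey_patch` and `prepare_seq_parallel_inputs` (see its
+`__init__.py`). A minimal driver sketch -- `model` and `batch` here are
+hypothetical, and it assumes the `dist_flash_attn` prepare helper returns the
+same `local_*` keys as the `data_parallel` branch shown in `__init__.py`:
+
+```
+import torch.distributed as dist
+
+from axolotl.integrations.easy_context import (
+    apply_seq_parallel_monkey_patch,
+    prepare_seq_parallel_inputs,
+)
+
+# Patch the HF LLaMA attention once, before training starts.
+apply_seq_parallel_monkey_patch("dist_flash_attn", "llama")
+
+# Each step: shard the full-length batch across sequence-parallel ranks.
+prepared = prepare_seq_parallel_inputs(
+    "dist_flash_attn",
+    batch["input_ids"],
+    batch["position_ids"],
+    batch["labels"],
+    dist.get_rank(),
+    dist.get_world_size(),
+    "cuda",
+)
+outputs = model(
+    input_ids=prepared["local_input_ids"],
+    position_ids=prepared["local_position_ids"],
+    labels=prepared["local_target_ids"],
+)
+```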
+
+```
+@article{li2023lightseq,
+  title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANSFORMERS},
+  author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P and Gonzalez, Joseph E and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
+  journal={arXiv preprint arXiv:2310.03294},
+  year={2023}
+}
+```
diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py b/src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py
new file mode 100644
index 000000000..1ea4aded1
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py
@@ -0,0 +1,776 @@
+import math
+import os
+import threading
+
+import torch
+import torch.distributed as dist
+from torch.distributed import P2POp, batch_isend_irecv, irecv, isend
+
+# Sequence parallel group that the current rank belongs to.
+_SEQUENCE_PARALLEL_GROUP = None
+
+# These values enable us to change the sequence parallel sizes on the fly.
+_SEQUENCE_PARALLEL_SIZE = None
+_SEQUENCE_PARALLEL_RANK = None
+
+# Global buffer for P2P
+_PEER_Q = None
+_PEER_K = None
+_PEER_V = None
+_PEER_M = None
+_PEER_L = None
+_PEER_O = None
+_PEER_Q_BWD = None
+_PEER_K_BWD = None
+_PEER_V_BWD = None
+_PEER_O_BWD = None
+
+_DELTA_DQ = None
+_PEER_L = None
+_DELTA_DK = None
+_DELTA_DV = None
+_DK_DELTA_FROM_PEER = None
+_DV_DELTA_FROM_PEER = None
+_PEER_DO = None
+
+
+_fwd_send_volume = 0
+_fwd_recv_volume = 0
+_bwd_send_volume = 0
+_bwd_recv_volume = 0
+
+
+def initialize_distributed():
+    if dist.is_initialized():
+        if dist.get_rank() == 0:
+            print(
+                "torch distributed is already initialized, "
+                "skipping initialization ...",
+                flush=True,
+            )
+    else:
+        if int(os.environ["RANK"]) == 0:
+            print("Initializing Torch distributed.")
+        dist.init_process_group(backend="nccl")
+        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+        global_world_size = dist.get_world_size()
+        torch.cuda.set_device(dist.get_rank() % local_world_size)
+
+    _initialize_sequence_parallel()
+
+
+# create_nccl_communicators()
+
+
+def _initialize_sequence_parallel(sequence_parallel_size=None):
+    # Get world size and rank. Ensure some consistencies.
+    assert (
+        sequence_parallel_size is None
+    ), "Multiple sequence parallel group not implemented."
+    assert torch.distributed.is_initialized()
+    world_size: int = torch.distributed.get_world_size()
+
+    if sequence_parallel_size is None:
+        sequence_parallel_size = world_size
+    else:
+        assert world_size % sequence_parallel_size == 0
+    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
+
+    rank = torch.distributed.get_rank()
+
+    # Build the sequence parallel groups.
+    global _SEQUENCE_PARALLEL_GROUP
+    global _SEQUENCE_PARALLEL_RANK
+    global _SEQUENCE_PARALLEL_SIZE
+
+    assert (
+        _SEQUENCE_PARALLEL_GROUP is None
+    ), "sequence parallel group is already initialized"
+    for i in range(num_sequence_parallel_groups):
+        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
+        group = torch.distributed.new_group(ranks)
+        if rank in ranks:
+            _SEQUENCE_PARALLEL_GROUP = group
+            _SEQUENCE_PARALLEL_RANK = ranks.index(rank)
+            _SEQUENCE_PARALLEL_SIZE = len(ranks)
+
+    if dist.get_rank() == 0:
+        print("************ Finished sequence parallel group initialization. ***********")
+    # _set_global_memory_buffer()
+
+
+def maybe_get_set_global_memory_buffer(q, k, v, m, l, o):
+    global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
+    if _PEER_Q is None:
+        try:
+            if get_sequence_parallel_rank() == 0:
+                print("Initializing global memory buffer.")
+        except:
+            print("Initializing global memory buffer.")
+        _PEER_Q = [torch.empty_like(q) for _ in range(2)]
+        _PEER_K = [torch.empty_like(k) for _ in range(2)]
+        _PEER_V = [torch.empty_like(v) for _ in range(2)]
+        _PEER_M = [torch.empty_like(m) for _ in range(2)]
+        _PEER_L = [torch.empty_like(l) for _ in range(2)]
+        _PEER_O = [torch.empty_like(o) for _ in range(2)]
+
+    return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
+
+
+def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do):
+    global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO
+    if _DELTA_DQ is None:
+        try:
+            if get_sequence_parallel_rank() == 0:
+                print("Initializing global memory buffer for backward.")
+        except:
+            print("Initializing global memory buffer for backward.")
+        _DELTA_DQ = [torch.empty_like(dq) for _ in range(2)]
+        _DELTA_DK = [torch.empty_like(dk) for _ in range(2)]
+        _DELTA_DV = [torch.empty_like(dv) for _ in range(2)]
+        _PEER_L = [torch.empty_like(L) for _ in range(2)]
+
+        _DK_DELTA_FROM_PEER = torch.empty_like(dk)
+        _DV_DELTA_FROM_PEER = torch.empty_like(dv)
+
+        # May already be initialized in the forward call;
+        # the current forward and backward need a transpose in q's format.
+        _PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)]
+        _PEER_K_BWD = [torch.empty_like(k) for _ in range(2)]
+        _PEER_V_BWD = [torch.empty_like(v) for _ in range(2)]
+        _PEER_O_BWD = [torch.empty_like(o) for _ in range(2)]
+
+        _PEER_DO = [torch.empty_like(do) for _ in range(2)]
+
+    return (
+        _DELTA_DQ,
+        _DELTA_DK,
+        _DELTA_DV,
+        _DK_DELTA_FROM_PEER,
+        _DV_DELTA_FROM_PEER,
+        _PEER_Q_BWD,
+        _PEER_L,
+        _PEER_K_BWD,
+        _PEER_V_BWD,
+        _PEER_O_BWD,
+        _PEER_DO,
+    )
+
+
+def reset_global_memory_buffer():
+    global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO
+    _PEER_Q = None
+    _PEER_K = None
+    _PEER_V = None
+    _PEER_M = None
+    _PEER_L = None
+    _PEER_O = None
+
+    _DELTA_DQ = None
+    _PEER_L = None
+    _DELTA_DK = None
+    _DELTA_DV = None
+    _DK_DELTA_FROM_PEER = None
+    _DV_DELTA_FROM_PEER = None
+    _PEER_DO = None
+
+
+# PyTorch defers the creation of NCCL communicators to the first P2P call;
+# we manually create them so the first isend does not hang without an irecv.
+# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138
+# Only an even number of GPUs is supported.
+def create_nccl_communicators():
+    seq_rank = get_sequence_parallel_rank()
+    seq_group = get_sequence_parallel_group()
+
+    empty_tensor = torch.empty(
+        1,
+    ).cuda()
+    empty_tensor_2 = torch.empty(
+        1,
+    ).cuda()
+    if torch.distributed.get_rank() % 2 == 0:
+        # sender
+        op1 = P2POp(
+            op=isend,
+            tensor=torch.empty(
+                1,
+            ).cuda(),
+            peer=seq_rank + 1,
+            group=seq_group,
+        )
+        op2 = P2POp(
+            op=irecv,
+            tensor=torch.empty(
+                1,
+            ).cuda(),
+            peer=seq_rank + 1,
+            group=seq_group,
+        )
+        # req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group)
+        dist.batch_isend_irecv([op1, op2])
+    else:
+        # receiver
+        op1 = P2POp(
+            op=irecv,
+            tensor=torch.empty(
+                1,
+            ).cuda(),
+            peer=seq_rank - 1,
+            group=seq_group,
+        )
+        op2 = P2POp(
+            op=isend,
+            tensor=torch.empty(
+                1,
+            ).cuda(),
+            peer=seq_rank - 1,
+            group=seq_group,
+        )
+        # req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group)
+        handles = dist.batch_isend_irecv([op1, op2])
+        # req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group)
+    dist.all_reduce(empty_tensor, group=seq_group)
+
+
+def get_sequence_parallel_group():
+    """Get the sequence parallel group the caller rank belongs to."""
+    # global _SEQUENCE_PARALLEL_GROUP
+    assert (
+        _SEQUENCE_PARALLEL_GROUP is not None
+    ), "sequence parallel group is not initialized"
+    return _SEQUENCE_PARALLEL_GROUP
+
+
+def get_sequence_parallel_rank():
+    """Return my rank for the sequence parallel group."""
+    global _SEQUENCE_PARALLEL_RANK
+    if _SEQUENCE_PARALLEL_RANK is not None:
+        return _SEQUENCE_PARALLEL_RANK
+    return torch.distributed.get_rank(group=get_sequence_parallel_group())
+
+
+def get_sequence_parallel_size():
+    """Return the world size of the sequence parallel group."""
+    global _SEQUENCE_PARALLEL_SIZE
+    if _SEQUENCE_PARALLEL_SIZE is not None:
+        return _SEQUENCE_PARALLEL_SIZE
+    return torch.distributed.get_world_size(group=get_sequence_parallel_group())
+
+
+def destroy_sequence_parallel():
+    """Set the groups to none."""
+    global _SEQUENCE_PARALLEL_GROUP
+    _SEQUENCE_PARALLEL_GROUP = None
+
+
+# Whether this is the last time the kernel is being called
+def is_last_time(time_step):
+    # e.g. on an 8-GPU setup:
+    # R=0: 0
+    # R=1: 1
+    # R=2: 2
+    # R=3: 3
+    # R=4: 4, 5, 6, 7
+    seq_rank = get_sequence_parallel_rank()
+    seq_world_size = get_sequence_parallel_size()
+    if seq_rank <= seq_world_size // 2:  # no one helps these ranks
+        rank_finish_time = seq_rank
+    else:
+        rank_finish_time = seq_world_size // 2
+    return rank_finish_time == time_step
+
+
+# Whether the current time step is computing for local q
+def is_compute_for_local_query(time_step):
+    # R=3,4,5,6,7: Yes
+    # R=0: 0
+    # R=1: 0, 1
+    # R=2: 0, 1, 2
+    seq_rank = get_sequence_parallel_rank()
+    seq_world_size = get_sequence_parallel_size()
+    if seq_rank >= min(seq_world_size // 2, time_step):
+        return True
+    return False
+
+
+# Whether the current time step is idle
+def is_idle(time_step):
+    # 0, 1, 2, 3: 4
+    # 4, 5, 6, 7: No
+    seq_rank = get_sequence_parallel_rank()
+    seq_world_size = get_sequence_parallel_size()
+
+    if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2:
+        return True
+    return False
+
+
+# Whether the current time step needs to synchronize with a remote computed result
+def is_sync_from_remote(time_step):
+    # R=0, 1, 2, 3, 4: No
+    # R=5: 4
+    # R=6: 3, 4
+    # R=7: 2, 3, 4
+    seq_rank = get_sequence_parallel_rank()
+    seq_world_size = get_sequence_parallel_size()
+    if seq_rank > max(seq_world_size // 2, seq_world_size - time_step):
+        return True
+    return False
+
+
+def maybe_send_recv_fwd_qkvo(
+    q: torch.Tensor,
+    peer_q: torch.Tensor,
+    k: torch.Tensor,
+    peer_k: torch.Tensor,
+    v: torch.Tensor,
+    peer_v: torch.Tensor,
+    o_stats: list,  # peer_o_stats: list,
+    time_step: int,
+    comm_mode,
+    debug=False,
+) -> torch.Tensor:
+    seq_group = get_sequence_parallel_group()
+    seq_rank = get_sequence_parallel_rank()
+    seq_world_size = get_sequence_parallel_size()
+
+    # Handles for operations that actually need to be waited on before going
+    # to the next iteration. For instance, the QKV sender never needs to wait
+    # -> it seems fusing these calls helps the scheduler.
+    all_handles = []
+    # KV logic: unlike the older version, every rank sends/recvs its own kv,
+    # to balance communication. In a balanced communication, at every step
+    # each rank should send/recv 4 tensors in total (kv, or qo). For instance,
+    # rank 0, when time step > 0, should send its own kv and send/recv qo. In
+    # the older version, rank 0 does not send its kv and relies on a later
+    # rank to pass it along, where the later rank has to (1) receive kv,
+    # (2) send rank 0's kv, and (3) send/recv qo.
+    # Q (load balancing) logic: semantically, this will be "%" world size, so
+    # the same send/recv rank as KV. Note: only an even number of machines is
+    # supported.
+    # O (load balancing) logic: rank 0 sends its result to rank 7 at time 1.
+    # It gets delayed for one time step, and thus has a different
+    # maybe_send/recv_rank. Use (time_step + 1) to easily convert to the
+    # synchronized version.
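+    # Concretely, with 8 ranks at t=0: rank 7 has maybe_send_rank == 8, which
+    # wraps to rank 0 (8 % 8), so it sends its q there; ranks 0..6, whose send
+    # target is still in range, send their own k/v to rank + 1 instead.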
+ maybe_send_rank = seq_rank + (time_step + 1) + maybe_recv_rank = seq_rank - (time_step + 1) + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + _debug_send = _fwd_send_volume + _debug_recv = _fwd_recv_volume + + if maybe_send_rank >= seq_world_size: + # send q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + # print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)") + # q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) + all_handles.append( + P2POp( + op=isend, + tensor=q, + peer=maybe_send_rank % seq_world_size, + group=seq_group, + ) + ) + if debug: + _fwd_send_volume += torch.numel(q) * q.element_size() + else: + # send kv + # print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)") + # kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) + # kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) + all_handles.append( + P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group) + ) + all_handles.append( + P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group) + ) + if debug: + _fwd_send_volume += torch.numel(k) * k.element_size() + _fwd_send_volume += torch.numel(v) * v.element_size() + + if maybe_recv_rank < 0: + # recv q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + # print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)") + # q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) + all_handles.append( + P2POp( + op=irecv, + tensor=peer_q, + peer=maybe_recv_rank % seq_world_size, + group=seq_group, + ) + ) + if debug: + _fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() + else: + # recv kv + # print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)") + # kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) + # kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) + all_handles.append( + P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group) + ) + all_handles.append( + P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group) + ) + if debug: + _fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() + _fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() + + maybe_send_rank_o = seq_rank - (time_step - 1) + maybe_recv_rank_o = seq_rank + (time_step - 1) + if maybe_send_rank_o < 0 and time_step > 1: + for t in o_stats: + # print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)") + # o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) + all_handles.append( + P2POp( + op=isend, + tensor=t, + peer=maybe_send_rank_o % seq_world_size, + group=seq_group, + ) + ) + if debug: + _fwd_send_volume += torch.numel(t) * t.element_size() + if maybe_recv_rank_o >= seq_world_size and time_step > 1: + for t in o_stats: + # print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)") + # o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) + all_handles.append( + P2POp( + op=irecv, + tensor=t, + peer=maybe_recv_rank_o % 
seq_world_size, + group=seq_group, + ) + ) + if debug: + _fwd_recv_volume += torch.numel(t) * t.element_size() + + # reqs = [] + + if debug: + if seq_rank in [0, 8]: + print( + f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB" + ) + # return reqs + all_reqs = launch_async_handles(all_handles, comm_mode) + return [all_reqs] + + +# delta: may be you are using it for your local compute or as a distributed buffer to send to others +# .. Sorry for the bad naming.. +def maybe_send_recv_bwd_qkvo( + dq_delta: torch.Tensor, + dk_delta: torch.Tensor, + dv_delta: torch.Tensor, + dk_delta_from_peer: torch.Tensor, + dv_delta_from_peer: torch.Tensor, + q: torch.Tensor, + peer_q: torch.Tensor, + L: torch.Tensor, + peer_L: torch.Tensor, + k: torch.Tensor, + peer_k: torch.Tensor, + v: torch.Tensor, + peer_v: torch.Tensor, + o: torch.Tensor, + peer_o: torch.Tensor, + do: torch.Tensor, + peer_do: torch.Tensor, + time_step: int, + comm_mode, + debug=False, +): + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + all_handles = [] + maybe_send_rank = seq_rank + (time_step + 1) + maybe_recv_rank = seq_rank - (time_step + 1) + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + + if maybe_send_rank >= seq_world_size: + # send q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + all_handles.append( + P2POp( + op=isend, + tensor=q, + peer=maybe_send_rank % seq_world_size, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=isend, + tensor=L, + peer=maybe_send_rank % seq_world_size, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=isend, + tensor=o, + peer=maybe_send_rank % seq_world_size, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=isend, + tensor=do, + peer=maybe_send_rank % seq_world_size, + group=seq_group, + ) + ) + if debug: + _bwd_send_volume += torch.numel(q) * q.element_size() + _bwd_send_volume += torch.numel(L) * L.element_size() + _bwd_send_volume += torch.numel(o) * o.element_size() + _bwd_send_volume += torch.numel(do) * do.element_size() + else: + # send kv + all_handles.append( + P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group) + ) + all_handles.append( + P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group) + ) + if debug: + _bwd_send_volume += torch.numel(k) * k.element_size() + _bwd_send_volume += torch.numel(v) * v.element_size() + + if maybe_recv_rank < 0: + # recv q, no one needs to do remote computation in the last time step + if time_step < (seq_world_size // 2 - 1): + all_handles.append( + P2POp( + op=irecv, + tensor=peer_q, + peer=maybe_recv_rank % seq_world_size, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=irecv, + tensor=peer_L, + peer=maybe_recv_rank % seq_world_size, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=irecv, + tensor=peer_o, + peer=maybe_recv_rank % seq_world_size, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=irecv, + tensor=peer_do, + peer=maybe_recv_rank % seq_world_size, + group=seq_group, + ) + ) + if debug: + _bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() + _bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size() + _bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size() + _bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size() + 
else: + # recv kv + all_handles.append( + P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group) + ) + all_handles.append( + P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group) + ) + if debug: + _bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() + _bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() + + # Whether I should update dq, dk and dv after waiting these requests + is_update_dq = False + is_update_dkv = False + + maybe_send_rank_dqkv = seq_rank - (time_step - 1) + maybe_recv_rank_dqkv = seq_rank + (time_step - 1) + + if time_step > 1: + if maybe_send_rank_dqkv < 0: + # print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}") + all_handles.append( + P2POp( + op=isend, + tensor=dq_delta, + peer=maybe_send_rank_dqkv % seq_world_size, + group=seq_group, + ) + ) + if debug: + _bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size() + else: + # print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}") + all_handles.append( + P2POp( + op=isend, + tensor=dk_delta, + peer=maybe_send_rank_dqkv, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=isend, + tensor=dv_delta, + peer=maybe_send_rank_dqkv, + group=seq_group, + ) + ) + if debug: + _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() + + if maybe_recv_rank_dqkv >= seq_world_size: + # print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}") + all_handles.append( + P2POp( + op=irecv, + tensor=dq_delta, + peer=maybe_recv_rank_dqkv % seq_world_size, + group=seq_group, + ) + ) + is_update_dq = True + if debug: + _bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size() + else: + # print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}") + all_handles.append( + P2POp( + op=irecv, + tensor=dk_delta_from_peer, + peer=maybe_recv_rank_dqkv, + group=seq_group, + ) + ) + all_handles.append( + P2POp( + op=irecv, + tensor=dv_delta_from_peer, + peer=maybe_recv_rank_dqkv, + group=seq_group, + ) + ) + is_update_dkv = True + if debug: + _bwd_recv_volume += ( + torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size() + ) + _bwd_recv_volume += ( + torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size() + ) + + # return [], is_update_dq, is_update_dkv + all_reqs = launch_async_handles(all_handles, comm_mode) + return [all_reqs], is_update_dq, is_update_dkv + + +def maybe_send_recv_bwd_last_dkv( + dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False +): + is_update_last_dkv = False + + seq_group = get_sequence_parallel_group() + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + if seq_world_size == 1: + return [], is_update_last_dkv + + all_handles = [] + + if debug: + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + + if time_step == seq_world_size // 2: + maybe_send_rank = seq_rank - time_step + maybe_recv_rank = seq_rank + time_step + + assert (maybe_send_rank >= 0) ^ ( + maybe_recv_rank < seq_world_size + ), "R={seq_rank} should be either sending or receiving dkv in the last time step." 
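+        # At t == seq_world_size // 2, the upper half of ranks (rank >= t)
+        # sends its accumulated dk/dv back to rank - t while the lower half
+        # receives from rank + t; the XOR above guarantees each rank takes
+        # exactly one branch.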
+ + if maybe_send_rank >= 0: + # print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}") + all_handles.append( + P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group) + ) + all_handles.append( + P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group) + ) + if debug: + _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() + if maybe_recv_rank < seq_world_size: + # print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}") + all_handles.append( + P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group) + ) + all_handles.append( + P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group) + ) + if debug: + _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size() + _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size() + is_update_last_dkv = True + + # return [], is_update_last_dkv + all_reqs = launch_async_handles(all_handles, comm_mode) + + return [all_reqs], is_update_last_dkv + + +def print_and_reset_comm_stats(): + seq_rank = get_sequence_parallel_rank() + + global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume + _fwd_send_volume *= 1e-9 + _fwd_recv_volume *= 1e-9 + _bwd_send_volume *= 1e-9 + _bwd_recv_volume *= 1e-9 + + print( + f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB." + ) + _fwd_send_volume = 0 + _fwd_recv_volume = 0 + _bwd_send_volume = 0 + _bwd_recv_volume = 0 + + +def launch_async_handles(handles, comm_mode): + global _args + if comm_mode == "nocomm": + # print("skipping communication for ablation") + return [] + if len(handles) > 0: + return dist.batch_isend_irecv(handles) + return [] + + +def wait_async_handles(reqs): + if len(reqs) > 0: + for req in reqs: + for r in req: + r.wait() diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py b/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py new file mode 100644 index 000000000..e7df1d88c --- /dev/null +++ b/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py @@ -0,0 +1,1209 @@ +import argparse + +# from torch.profiler import profile, record_function, ProfilerActivity +import functools +import math +import os +import time + +import numpy as np +import pytest +import torch +import torch.distributed as dist +import triton +import triton.language as tl +from einops import rearrange +from torch.distributed import ReduceOp +from tqdm import tqdm + +try: + from flash_attn.flash_attn_interface import ( + _flash_attn_backward, + _flash_attn_forward, + ) +except: + pass + +from .async_communication import ( + get_sequence_parallel_rank, + get_sequence_parallel_size, + initialize_distributed, + is_compute_for_local_query, + is_idle, + is_last_time, + is_sync_from_remote, + launch_async_handles, + maybe_get_set_global_memory_buffer, + maybe_get_set_global_memory_buffer_bwd, + maybe_send_recv_bwd_last_dkv, + maybe_send_recv_bwd_qkvo, + maybe_send_recv_fwd_qkvo, + print_and_reset_comm_stats, + reset_global_memory_buffer, + wait_async_handles, +) + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def _rescale_kernel( + peer_m, + m, + peer_l, + l, + peer_o, + o, + L, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: 
tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + o_offset = off_hz * stride_oh + peer_o_block_ptr = tl.make_block_ptr( + base=peer_o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + o_block_ptr = tl.make_block_ptr( + base=o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m + m_ptrs = m + off_hz * N_CTX + offs_m + peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m + l_ptrs = l + off_hz * N_CTX + offs_m + + peer_m_i = tl.load(peer_m_ptrs) + peer_m_i = peer_m_i.to(tl.float32) + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + peer_l_i = tl.load(peer_l_ptrs) + peer_l_i = peer_l_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + + peer_acc = tl.load(peer_o_block_ptr) + peer_acc = peer_acc.to(tl.float32) + acc = tl.load(o_block_ptr) + acc = acc.to(tl.float32) + lo = 0 + hi = N_CTX + m_i_sync = tl.maximum(m_i, peer_m_i) + alpha = tl.math.exp2(m_i - m_i_sync) + peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug + + acc *= acc_scale[:, None] + peer_acc *= peer_acc_scale[:, None] + acc += peer_acc + l_i = l_i * acc_scale + peer_l_i * peer_acc_scale + # write back O, l, m + tl.store(m_ptrs, m_i_sync) + tl.store(l_ptrs, l_i) + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) + tl.store(o_block_ptr, acc.to(tl.bfloat16)) + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + m, + l, + O, + L, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l -> load from provided pointer + m_ptrs = m + off_hz * N_CTX + offs_m + l_ptrs = l + 
off_hz * N_CTX + offs_m + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + acc = tl.load(O_block_ptr) + acc = acc.to(tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(tl.bfloat16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where( + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf") + ) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.bfloat16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back original l and m + tl.store(m_ptrs, m_i) + tl.store(l_ptrs, l_i) + # write back O, L + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) + tl.store(O_block_ptr, acc.to(tl.bfloat16)) + + +# for gqa/mqa to expand kv heads +def maybe_repeat_kv_fwd(nqh, kv): + bs, nkvh, slen, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) + return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) + + +def maybe_repeat_kv_bwd(nqh, kv): + bs, slen, nkvh, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) + return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) + + +# kv grad has shape bs, slen, nqh, hdim +def maybe_reduce_dkv(nkvh, dkv): + bs, slen, nqh, hdim = dkv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return dkv + dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) + return torch.sum(dkv_reshape, dim=3) + + +def _lightseq_forward(q, k, v, causal, sm_scale, comm_mode): + # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + # Why do I have to change it from 128 64 to 32 32? 
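+    # 32x32 are the tile sizes shipped by upstream LightSeq; larger tiles
+    # (e.g. 128x64) can exceed shared-memory/register budgets at head_dim 128
+    # on some GPUs, so the conservative defaults are kept.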
+ BLOCK_M = 32 + BLOCK_N = 32 + + bsz, nh, seq_len, hdim = q.shape + + m = torch.full( + (bsz * nh, seq_len), + fill_value=-float("inf"), + device=q.device, + dtype=torch.float32, + ) + l = torch.zeros_like(m) + L = torch.zeros_like(m) + o = torch.zeros_like(q) + + grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1) + num_warps = 4 if Lk <= 64 else 8 + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all buffers + peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer( + q, k, v, m, l, o + ) + + fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[ + grid + ]( + q, + k, + v, + sm_scale, + m, + l, + o, + L, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + IS_CAUSAL=IS_CAUSAL, + LAST_STEP=LAST_STEP, + num_warps=num_warps, + num_stages=4, + ) + + for time_step in range(seq_world_size // 2 + 1): + # This is important for cuda scheduler to execute nccl calls first. + torch.cuda.synchronize() + # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs = maybe_send_recv_fwd_qkvo( + q, + peer_q[buffer_idx_1], + k, + peer_k[buffer_idx_1], + v, + peer_v[buffer_idx_1], + [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], + time_step, + comm_mode, + ) + if comm_mode == "sync": + # if seq_rank == 0: + # print("Immediate wait for abalation") + wait_async_handles(reqs) + if is_compute_for_local_query(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} local compute") + if time_step == 0: + fwd_launch_helper( + q, + maybe_repeat_kv_fwd(q.shape[1], k), + maybe_repeat_kv_fwd(q.shape[1], v), + m, + l, + o, + L, + True, + is_last_time(time_step), + ) + else: + # if needs to sync from others, do not normalize here + fwd_launch_helper( + q, + maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), + maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), + m, + l, + o, + L, + False, + not is_sync_from_remote(time_step) and is_last_time(time_step), + ) + elif is_idle(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"t={time_step}: (Comp) R={seq_rank} helps other") + peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) + peer_l[buffer_idx_2] = torch.zeros_like(l) + peer_o[buffer_idx_2] = torch.zeros_like(o) + + # print(f"rank 3 q is: {peer_q[buffer_idx_2]}") + fwd_launch_helper( + peer_q[buffer_idx_2], + maybe_repeat_kv_fwd(q.shape[1], k), + maybe_repeat_kv_fwd(q.shape[1], v), + peer_m[buffer_idx_2], + peer_l[buffer_idx_2], + peer_o[buffer_idx_2], + None, + False, + False, + ) + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + # sync between statistics get from other ranks and the local ones + if is_sync_from_remote(time_step): + _rescale_kernel[grid]( + peer_m[buffer_idx_1], + m, + peer_l[buffer_idx_1], + l, + peer_o[buffer_idx_1], + o, + L, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + o.shape[0], + o.shape[1], + o.shape[2], + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + LAST_STEP=is_last_time(time_step), + 
num_warps=num_warps, + num_stages=4, + ) + return q, k, v, o, L + + +def _lightseq_backward(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine): + BLOCK = 128 + q, k, v, o, do = [ + rearrange(_x, "b h s d -> b s h d").contiguous() for _x in [q, k, v, o, do] + ] + L = rearrange(L, "(b h) s -> b h s", b=q.shape[0]) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + # maybe gqa + nqh = q.shape[2] + nkvh = k.shape[2] + is_gqa = nqh > nkvh + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all backward buffers + ( + dq_delta, + dk_delta, + dv_delta, + dk_delta_from_peer, + dv_delta_from_peer, + peer_q, + peer_L, + peer_k, + peer_v, + peer_o, + peer_do, + ) = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) + + for time_step in range(0, get_sequence_parallel_size() // 2 + 1): + torch.cuda.synchronize() + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo( + dq_delta[buffer_idx_1], + dk_delta[buffer_idx_1], + dv_delta[buffer_idx_1], + dk_delta_from_peer, + dv_delta_from_peer, + q, + peer_q[buffer_idx_1], + L, + peer_L[buffer_idx_1], + k, + peer_k[buffer_idx_1], + v, + peer_v[buffer_idx_1], + o, + peer_o[buffer_idx_1], + do, + peer_do[buffer_idx_1], + time_step, + comm_mode, + ) + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) Immediate wait for abalation") + wait_async_handles(reqs) + + if is_compute_for_local_query(time_step): + if time_step == 0: + if backward_engine == "flash": + _flash_attn_backward( + do, + q, + k, + v, + o, + L, + dq, + dk, + dv, + 0.0, + sm_scale, + True, + (-1, -1), + None, + False, + ) + else: + inp = Inputs( + query=q, + key=maybe_repeat_kv_bwd(q.shape[2], k), + value=maybe_repeat_kv_bwd(q.shape[2], v), + attn_bias=xformers.ops.LowerTriangularMask(), + p=0, + scale=sm_scale, + ) + op_ctx = Context(lse=L, out=o, rng_state=None) + # Let xformers dispatch the correct backend + grads = _memory_efficient_attention_backward( + ctx=op_ctx, inp=inp, grad=do, op=None + ) + dq = grads.dq + dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv( + nkvh, grads.dv + ) + else: + if backward_engine == "flash": + _flash_attn_backward( + do, + q, + peer_k[buffer_idx_2], + peer_v[buffer_idx_2], + o, + L, + dq_delta[buffer_idx_2], + dk_delta[buffer_idx_2], + dv_delta[buffer_idx_2], + 0.0, + sm_scale, + False, + (-1, -1), + None, + False, + ) + else: + inp = Inputs( + query=q, + key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), + value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), + attn_bias=None, + p=0, + scale=sm_scale, + ) + op_ctx = Context(lse=L, out=o, rng_state=None) + grads = _memory_efficient_attention_backward( + ctx=op_ctx, inp=inp, grad=do, op=None + ) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv( + nkvh, grads.dk + ), maybe_reduce_dkv(nkvh, grads.dv) + dq += dq_delta[buffer_idx_2] + elif is_idle(time_step): + pass + else: + if backward_engine == "flash": + _flash_attn_backward( + peer_do[buffer_idx_2], + peer_q[buffer_idx_2], + k, + v, + peer_o[buffer_idx_2], + peer_L[buffer_idx_2], + dq_delta[buffer_idx_2], + dk_delta[buffer_idx_2], + dv_delta[buffer_idx_2], + 0.0, + sm_scale, + False, + (-1, -1), + None, + False, + ) + else: + inp = Inputs( + query=peer_q[buffer_idx_2], + key=maybe_repeat_kv_bwd(q.shape[2], k), + value=maybe_repeat_kv_bwd(q.shape[2], v), + attn_bias=None, + p=0, + 
scale=sm_scale, + ) + op_ctx = Context( + lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None + ) + grads = _memory_efficient_attention_backward( + ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None + ) + dq_delta[buffer_idx_2] = grads.dq + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv( + nkvh, grads.dk + ), maybe_reduce_dkv(nkvh, grads.dv) + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + + # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. + reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv( + dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode + ) + + if comm_mode == "sync": + # if seq_rank == 0: + # print("(bwd) dkv Immediate wait for abalation") + wait_async_handles(reqs) + # apply dq_delta, dk_delta and dv_delta from remote + if is_update_dq: + dq += dq_delta[buffer_idx_1] + if is_update_dkv: + dk += dk_delta_from_peer + dv += dv_delta_from_peer + + if comm_mode == "lightseq": + wait_async_handles(reqs) + # apply dk_delta and dv_delta to sender + if is_update_last_dkv: + dk += dk_delta[buffer_idx_2] + dv += dv_delta[buffer_idx_2] + + dq, dk, dv = [rearrange(_x, "b h s d -> b s h d") for _x in [dq, dk, dv]] + return dq, dk, dv + + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale): + try: + global args + comm_mode = args.comm_mode + backward_engine = args.backward_engine + except: + comm_mode = "lightseq" + backward_engine = "flash" + + q, k, v, o, L = _lightseq_forward(q, k, v, causal, sm_scale, comm_mode) + + ctx.save_for_backward(q, k, v, o, L) + ctx.sm_scale = sm_scale + ctx.comm_mode = comm_mode + ctx.backward_engine = backward_engine + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, L = ctx.saved_tensors + sm_scale = ctx.sm_scale + + dq, dk, dv = _lightseq_backward( + do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine + ) + return dq, dk, dv, None, None + + +attention = _attention.apply + + +# @pytest.mark.parametrize('causal', [False, True]) +# @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(20) + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + + a, b, c, d = q.size() + real_q = ( + q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + 
.requires_grad_(True) + ) + real_k = ( + k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_v = ( + v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_do = ( + dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + + tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose( + ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_out, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose( + ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dq, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails backward dq" + assert torch.allclose( + ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dk, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dk" # f" {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose( + ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dv, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + + +def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(177) + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # torch reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) + ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) + p = torch.matmul(q, ref_k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, ref_v) + ref_out.backward(dout) + ref_dv, v.grad = ref_v.grad.clone(), None + ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1, 2))).transpose(1, 2) + ref_dk, k.grad = ref_k.grad.clone(), None + ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1, 2))).transpose(1, 2) + ref_dq, q.grad = q.grad.clone(), None + + # flash reference + from flash_attn import flash_attn_func, flash_attn_qkvpacked_func + + flash_q = q.transpose(1, 2).clone().detach().requires_grad_(True) + flash_k = k.transpose(1, 2).clone().detach().requires_grad_(True) + flash_v = 
v.transpose(1, 2).clone().detach().requires_grad_(True) + flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) + flash_ref_out.backward(dout.transpose(1, 2)) + flash_ref_out = flash_ref_out.transpose(1, 2) + flash_ref_dv, v.grad = flash_v.grad.clone(), None + flash_ref_dv = flash_ref_dv.transpose(1, 2) + flash_ref_dk, k.grad = flash_k.grad.clone(), None + flash_ref_dk = flash_ref_dk.transpose(1, 2) + flash_ref_dq, q.grad = flash_q.grad.clone(), None + flash_ref_dq = flash_ref_dq.transpose(1, 2) + + # triton implementation + + a, b, c, d = q.size() + real_q = ( + q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_k = ( + k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, KVH, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_v = ( + v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, KVH, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_do = ( + dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + + tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose( + flash_ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_out, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails forward against flash" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose( + flash_ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dq, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails backward dq against flash" + # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose( + flash_ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dk, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose( + flash_ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dv, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward against flash") + + assert torch.allclose( + ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_out, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + assert torch.allclose( + ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dq, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails backward dq" + # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose( + ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dk, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, 
:]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose( + ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dv, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + + +# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +try: + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as flash_attn_func, + ) + + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None +HAS_FLASH = None +ONLY_FLASH = False + +# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[ + 2**i for i in range(18, 19) + ], # [ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], + line_arg="provider", + line_vals=["triton"] + if not ONLY_FLASH + else [] + (["flash"] if HAS_FLASH else []), + line_names=["Triton"] + if not ONLY_FLASH + else [] + ([f"Flash-{FLASH_VER}"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.bfloat16, + "mode": mode, + "causal": causal, + }, + ) + for mode in ["all"] + for causal in [True] +] + + +# @triton.testing.perf_report(configs) +def bench_flash_attention( + BATCH, + H, + KVH, + N_CTX, + D_HEAD, + causal, + mode, + provider, + args, + dtype=torch.bfloat16, + device="cuda", +): + assert mode == "all" # mode in ['fwd', 'bwd'] + n_warmup = 10 + n_repeat = 10 + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if provider == "triton": + q = torch.randn( + (BATCH, H, N_CTX // seq_world_size, D_HEAD), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + k = torch.randn( + (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + v = torch.randn( + (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + if seq_rank == 0: + print(f"Benchmarking per GPU qkv shape: {q.shape}") + sm_scale = 1.3 + fwd_fn = lambda: attention(q, k, v, causal, sm_scale) + if provider == "flash": + qkv = torch.randn( + (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True + ) + if FLASH_VER == 1: + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0.0, N_CTX, causal=causal) + elif FLASH_VER == 2: + fwd_fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f"unknown {FLASH_VER = }") + + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size + attn_flops = 2 * flops_per_matmul + + assert causal + if causal: + attn_flops *= 0.5 + fwd_flops = attn_flops + bwd_flops 
= attn_flops * 2.5  # 2.0(bwd) + 0.5(recompute)
+
+    o = fwd_fn()
+    do = torch.randn_like(o)
+    bwd_fn = lambda: o.backward(do, retain_graph=True)
+
+    def run_benchmark(fn):
+        time_list = []
+        for _ in tqdm(range(n_warmup)):
+            cache.zero_()
+            fn()
+            torch.cuda.synchronize()
+            if args.debug:
+                print_and_reset_comm_stats()
+        for i in tqdm(range(n_repeat)):
+            cache.zero_()
+            torch.cuda.synchronize()
+            time_s = time.time()
+            fn()
+            torch.cuda.synchronize()
+            time_e = time.time()
+            time_list.append((time_e - time_s) * 1000.0)
+            if args.debug:
+                print_and_reset_comm_stats()
+        return np.asarray(time_list)
+
+    fwd_time_arr = run_benchmark(fwd_fn)
+    bwd_time_arr = run_benchmark(bwd_fn)
+
+    fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9
+    print(
+        f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n"
+    )
+
+    bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9
+    print(
+        f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n"
+    )
+
+    # total
+    total_time_arr = fwd_time_arr + bwd_time_arr
+    total_flops = fwd_flops + bwd_flops
+    total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9
+    print(
+        f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n"
+    )
+
+    # return total_flops_ps
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--comm-mode", type=str, default="lightseq")
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--run-mode", type=str, default="benchmark")
+    parser.add_argument("--bs", type=int, default=1)
+    parser.add_argument("--n_heads", type=int, default=32)
+    parser.add_argument("--n_kvheads", type=int, default=32)
+    parser.add_argument("--d_head", type=int, default=128)
+    parser.add_argument("--start_ctx", type=int, default=12)
+    parser.add_argument("--end_ctx", type=int, default=18)
+    parser.add_argument("--forward_engine", type=str, default="triton")
+    parser.add_argument("--backward_engine", type=str, default="flash")
+
+    global args
+    args = parser.parse_args()
+    initialize_distributed()
+
+    assert args.forward_engine == "triton", "Only triton forward is implemented."
+    assert args.backward_engine in [
+        "flash",
+        "xformers",
+    ], "Only flash or xformers backward is implemented."
+
+    if args.backward_engine == "flash":
+        from flash_attn.flash_attn_interface import (
+            _flash_attn_backward,
+            _flash_attn_forward,
+        )
+    else:
+        try:
+            import xformers.ops
+            from xformers.ops.fmha import (
+                _memory_efficient_attention_backward,
+                cutlass,
+                flash,
+            )
+            from xformers.ops.fmha.common import Context, Inputs
+        except ImportError:
+            print("xformers not found! 
Please install it before trying to use it.") + + if args.run_mode == "benchmark": + for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: + bench_flash_attention( + args.bs, + args.n_heads, + args.n_kvheads, + N_CTX, + args.d_head, + True, + "all", + "triton", + args, + ) # .run(save_path='.', print_data=True) + reset_global_memory_buffer() + else: + assert args.run_mode == "test" + for N_CTX in [2048, 4096]: + test_op(1, 16, N_CTX, 128, True) + # test_gqa(1, 16, 8, N_CTX, 128, True) + reset_global_memory_buffer() diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py b/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py new file mode 100644 index 000000000..2f1916903 --- /dev/null +++ b/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py @@ -0,0 +1,1273 @@ +import argparse +import math +import os +import time + +import numpy as np +import pytest +import torch +import torch.distributed as dist +import triton +import triton.language as tl +from einops import rearrange +from torch.distributed import ReduceOp +from tqdm import tqdm + +# from torch.profiler import profile, record_function, ProfilerActivity + + +try: + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward +except: + pass + +from .async_communication import ( + get_sequence_parallel_rank, + get_sequence_parallel_size, + initialize_distributed, + is_compute_for_local_query, + is_idle, + is_last_time, + is_sync_from_remote, + launch_async_handles, + maybe_get_set_global_memory_buffer, + maybe_get_set_global_memory_buffer_bwd, + maybe_send_recv_bwd_last_dkv, + maybe_send_recv_bwd_qkvo, + maybe_send_recv_fwd_qkvo, + print_and_reset_comm_stats, + reset_global_memory_buffer, + wait_async_handles, +) + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def _rescale_kernel( + peer_m, + m, + peer_l, + l, + peer_o, + o, + L, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + seqlen_q_rounded, + seqlen_peer_q_rounded, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + o_offset = off_hz * stride_oh + peer_o_block_ptr = tl.make_block_ptr( + base=peer_o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + o_block_ptr = tl.make_block_ptr( + base=o + o_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m + m_ptrs = m + off_hz * seqlen_q_rounded + offs_m + peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m + l_ptrs = l + off_hz * seqlen_q_rounded + offs_m + + peer_m_i = tl.load(peer_m_ptrs) + peer_m_i = peer_m_i.to(tl.float32) + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + peer_l_i = tl.load(peer_l_ptrs) + peer_l_i = peer_l_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + + peer_acc = tl.load( + peer_o_block_ptr + ) # , boundary_check=(0, 1), padding_option='zero') + peer_acc = peer_acc.to(tl.float32) + acc = tl.load(o_block_ptr) # , boundary_check=(0, 1), padding_option='zero') + acc = acc.to(tl.float32) + 
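# -- merge the peer's partial attention result into the local one --
+    # m/l are the running softmax max/sum of the local partial result and
+    # peer_m/peer_l those of the remote one; both accumulators are rescaled
+    # to the shared max with exp2 and added (the usual online-softmax combine).
+    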
lo = 0 + hi = N_CTX + m_i_sync = tl.maximum(m_i, peer_m_i) + alpha = tl.math.exp2(m_i - m_i_sync) + peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug + + acc *= acc_scale[:, None] + peer_acc *= peer_acc_scale[:, None] + acc += peer_acc + l_i = l_i * acc_scale + peer_l_i * peer_acc_scale + # write back O, l, m + tl.store(m_ptrs, m_i_sync) + tl.store(l_ptrs, l_i) + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * N_CTX + offs_m + tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) + tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + m, + l, + O, + L, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + seqlen_q_rounded, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + LAST_STEP: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l -> load from provided pointer + # (TODO): Why float32? 
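+    # (A probable answer to the TODO above: m and l hold the running softmax
+    # max and sum, and keeping them in fp32 makes the exp2 rescaling across
+    # communication steps numerically stable.)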
+ m_ptrs = m + off_hz * seqlen_q_rounded + offs_m + l_ptrs = l + off_hz * seqlen_q_rounded + offs_m + m_i = tl.load(m_ptrs) + m_i = m_i.to(tl.float32) + l_i = tl.load(l_ptrs) + l_i = l_i.to(tl.float32) + acc = tl.load(O_block_ptr) + acc = acc.to(tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") + q = (q * qk_scale).to(tl.bfloat16) + # loop over k, v and update accumulator + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero") + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where( + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf") + ) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(tl.bfloat16), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back original l and m + tl.store(m_ptrs, m_i) + tl.store(l_ptrs, l_i) + # write back O, L + if LAST_STEP: + acc = acc / l_i[:, None] + L_ptrs = L + off_hz * seqlen_q_rounded + offs_m + tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) + tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) + + +# for gqa/mqa to expand kv heads +def maybe_repeat_kv_fwd(nqh, kv): + bs, nkvh, slen, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) + return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) + + +def maybe_repeat_kv_bwd(nqh, kv): + bs, slen, nkvh, hdim = kv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return kv + kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) + return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) + + +# kv grad has shape bs, slen, nqh, hdim +def maybe_reduce_dkv(nkvh, dkv): + bs, slen, nqh, hdim = dkv.shape + n_rep = nqh // nkvh + if n_rep == 1: + return dkv + # print("*"*100, dkv.shape, bs, slen, nkvh, n_rep, hdim) + dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) + # print("-"*100, dkv_reshape.shape, bs, slen, nkvh, n_rep, hdim) + return torch.sum(dkv_reshape, dim=3) + + +def _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode): + # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + # assert Lq == Lk and Lk == Lv + # assert Lk in {16, 32, 64, 128} + BLOCK_M = 128 + BLOCK_N = 64 + + bsz, nh, unpadded_seq_len, hdim = q.shape + cu_seq_lens = torch.arange( + 0, + (bsz + 1) * unpadded_seq_len, + unpadded_seq_len, + dtype=torch.int32, + device=q.device, + ) + max_seqlen = unpadded_seq_len + seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M + + m = torch.full( + (bsz * nh, 
seqlen_q_rounded), + fill_value=-float("inf"), + device=q.device, + dtype=torch.float32, + ) + l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) + L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.zeros_like(q) + + grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1) + num_warps = 4 if Lk <= 64 else 8 + + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + + # Initialize all buffers + peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer( + q, k, v, m, l, o + ) + + fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[ + grid + ]( + q, + k, + v, + sm_scale, + m, + l, + o, + L, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + seqlen_q_rounded, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + IS_CAUSAL=IS_CAUSAL, + LAST_STEP=LAST_STEP, + num_warps=num_warps, + num_stages=4, + ) + + for time_step in range(seq_world_size // 2 + 1): + # This is important for cuda scheduler to execute nccl calls first. + torch.cuda.synchronize() + # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. + buffer_idx_1 = time_step % 2 + buffer_idx_2 = (time_step - 1) % 2 + + reqs = maybe_send_recv_fwd_qkvo( + q, + peer_q[buffer_idx_1], + k, + peer_k[buffer_idx_1], + v, + peer_v[buffer_idx_1], + [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], + time_step, + comm_mode, + ) + if comm_mode == "sync": + # if seq_rank == 0: + # print("Immediate wait for abalation") + wait_async_handles(reqs) + if is_compute_for_local_query(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} local compute") + if time_step == 0: + fwd_launch_helper( + q, + maybe_repeat_kv_fwd(q.shape[1], k), + maybe_repeat_kv_fwd(q.shape[1], v), + m, + l, + o, + L, + True, + is_last_time(time_step), + ) + else: + # if needs to sync from others, do not normalize here + fwd_launch_helper( + q, + maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), + maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), + m, + l, + o, + L, + False, + not is_sync_from_remote(time_step) and is_last_time(time_step), + ) + elif is_idle(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} idle") + pass + else: + # print(f"t={time_step}: (Comp) R={seq_rank} helps other") + peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) + peer_l[buffer_idx_2] = torch.zeros_like(l) + peer_o[buffer_idx_2] = torch.zeros_like(o) + + # print(f"rank 3 q is: {peer_q[buffer_idx_2]}") + fwd_launch_helper( + peer_q[buffer_idx_2], + maybe_repeat_kv_fwd(q.shape[1], k), + maybe_repeat_kv_fwd(q.shape[1], v), + peer_m[buffer_idx_2], + peer_l[buffer_idx_2], + peer_o[buffer_idx_2], + None, + False, + False, + ) + + if comm_mode == "lightseq": + # Make sure tensors for next steps are ready + wait_async_handles(reqs) + # sync between statistics get from other ranks and the local ones + if is_sync_from_remote(time_step): + # print(f"t={time_step}: (Comp) R={seq_rank} sync with other - last time: {is_last_time(time_step)}") + seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1] + _rescale_kernel[grid]( + peer_m[buffer_idx_1], + m, + peer_l[buffer_idx_1], + l, + 
peer_o[buffer_idx_1],
+                    o,
+                    L,
+                    o.stride(0),
+                    o.stride(1),
+                    o.stride(2),
+                    o.stride(3),
+                    o.shape[0],
+                    o.shape[1],
+                    o.shape[2],
+                    seqlen_q_rounded,
+                    seqlen_peer_q_rounded,
+                    BLOCK_M=BLOCK_M,
+                    BLOCK_N=BLOCK_N,
+                    BLOCK_DMODEL=Lk,
+                    LAST_STEP=is_last_time(time_step),
+                    num_warps=num_warps,
+                    num_stages=4,
+                )
+    return q, k, v, o, L, cu_seq_lens, max_seqlen
+
+
+def _lightseq_backward_varlen(
+    do, q, k, v, o, L, sm_scale, comm_mode, backward_engine, cu_seq_lens, max_seqlen
+):
+    BLOCK = 128
+    L = rearrange(L[:, :max_seqlen].contiguous(), "(b h) s -> b h s", b=q.shape[0])
+    q, k, v, o, do = [
+        rearrange(_x, "b h s d -> (b s) h d").contiguous() for _x in [q, k, v, o, do]
+    ]
+
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v)
+
+    # maybe gqa
+    nqh = q.shape[1]
+    nkvh = k.shape[1]
+    is_gqa = nqh > nkvh
+
+    seq_rank = get_sequence_parallel_rank()
+    seq_world_size = get_sequence_parallel_size()
+
+    # Initialize all backward buffers
+    (
+        dq_delta,
+        dk_delta,
+        dv_delta,
+        dk_delta_from_peer,
+        dv_delta_from_peer,
+        peer_q,
+        peer_L,
+        peer_k,
+        peer_v,
+        peer_o,
+        peer_do,
+    ) = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do)
+
+    for time_step in range(0, get_sequence_parallel_size() // 2 + 1):
+        torch.cuda.synchronize()
+        buffer_idx_1 = time_step % 2
+        buffer_idx_2 = (time_step - 1) % 2
+
+        reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(
+            dq_delta[buffer_idx_1],
+            dk_delta[buffer_idx_1],
+            dv_delta[buffer_idx_1],
+            dk_delta_from_peer,
+            dv_delta_from_peer,
+            q,
+            peer_q[buffer_idx_1],
+            L,
+            peer_L[buffer_idx_1],
+            k,
+            peer_k[buffer_idx_1],
+            v,
+            peer_v[buffer_idx_1],
+            o,
+            peer_o[buffer_idx_1],
+            do,
+            peer_do[buffer_idx_1],
+            time_step,
+            comm_mode,
+        )
+        if comm_mode == "sync":
+            wait_async_handles(reqs)
+
+        if is_compute_for_local_query(time_step):
+            if time_step == 0:
+                assert (
+                    backward_engine == "flash"
+                ), "Varlen is not yet supported with the xformers backward engine"
+                if backward_engine == "flash":
+                    _flash_attn_varlen_backward(
+                        do,
+                        q,
+                        k,
+                        v,
+                        o,
+                        L,
+                        dq,
+                        dk,
+                        dv,
+                        cu_seq_lens,
+                        cu_seq_lens,
+                        max_seqlen,
+                        max_seqlen,
+                        0.0,
+                        sm_scale,
+                        True,
+                        None,
+                    )
+                else:
+                    inp = Inputs(
+                        query=q,
+                        key=maybe_repeat_kv_bwd(q.shape[2], k),
+                        value=maybe_repeat_kv_bwd(q.shape[2], v),
+                        attn_bias=xformers.ops.LowerTriangularMask(),
+                        p=0,
+                        scale=sm_scale,
+                    )
+                    op_ctx = Context(lse=L, out=o, rng_state=None)
+                    # Let xformers dispatch the correct backend
+                    grads = _memory_efficient_attention_backward(
+                        ctx=op_ctx, inp=inp, grad=do, op=None
+                    )
+                    dq = grads.dq
+                    dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(
+                        nkvh, grads.dv
+                    )
+            else:
+                assert (
+                    backward_engine == "flash"
+                ), "Varlen is not yet supported with the xformers backward engine"
+                if backward_engine == "flash":
+                    _flash_attn_varlen_backward(
+                        do,
+                        q,
+                        peer_k[buffer_idx_2],
+                        peer_v[buffer_idx_2],
+                        o,
+                        L,
+                        dq_delta[buffer_idx_2],
+                        dk_delta[buffer_idx_2],
+                        dv_delta[buffer_idx_2],
+                        cu_seq_lens,
+                        cu_seq_lens,
+                        max_seqlen,
+                        max_seqlen,
+                        0.0,
+                        sm_scale,
+                        False,
+                        None,
+                    )
+                else:
+                    inp = Inputs(
+                        query=q,
+                        key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]),
+                        value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]),
+                        attn_bias=None,
+                        p=0,
+                        scale=sm_scale,
+                    )
+                    op_ctx = Context(lse=L, out=o, rng_state=None)
+                    grads = _memory_efficient_attention_backward(
+                        ctx=op_ctx, inp=inp, grad=do, op=None
+                    )
+                    dq_delta[buffer_idx_2] = grads.dq
+                    dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(
+                        nkvh, grads.dk
+                    ), maybe_reduce_dkv(nkvh, grads.dv)
+                dq += dq_delta[buffer_idx_2]
+        elif is_idle(time_step):
+            # print(f"BWD t={time_step}: (Comp) R={seq_rank} idle")
+            pass
+        else:
+            # print(f"BWD t={time_step}: (Comp) R={seq_rank} helps other")
+            assert (
+                backward_engine == "flash"
+            ), "Varlen is not yet supported with the xformers backward engine"
+            if backward_engine == "flash":
+                _flash_attn_varlen_backward(
+                    peer_do[buffer_idx_2],
+                    peer_q[buffer_idx_2],
+                    k,
+                    v,
+                    peer_o[buffer_idx_2],
+                    peer_L[buffer_idx_2],
+                    dq_delta[buffer_idx_2],
+                    dk_delta[buffer_idx_2],
+                    dv_delta[buffer_idx_2],
+                    cu_seq_lens,
+                    cu_seq_lens,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    sm_scale,
+                    False,
+                    None,
+                )
+            else:
+                inp = Inputs(
+                    query=peer_q[buffer_idx_2],
+                    key=maybe_repeat_kv_bwd(q.shape[2], k),
+                    value=maybe_repeat_kv_bwd(q.shape[2], v),
+                    attn_bias=None,
+                    p=0,
+                    scale=sm_scale,
+                )
+                op_ctx = Context(
+                    lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None
+                )
+                grads = _memory_efficient_attention_backward(
+                    ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None
+                )
+                dq_delta[buffer_idx_2] = grads.dq
+                dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(
+                    nkvh, grads.dk
+                ), maybe_reduce_dkv(nkvh, grads.dv)
+            dk += dk_delta[buffer_idx_2]
+            dv += dv_delta[buffer_idx_2]
+
+        if comm_mode == "lightseq":
+            # Make sure tensors for next steps are ready
+            wait_async_handles(reqs)
+
+        # The last time step needs to send dk and dv immediately; move it up here to maximize overlap with the following three additions.
+        reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(
+            dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode
+        )
+
+        if comm_mode == "sync":
+            # if seq_rank == 0:
+            #     print("(bwd) dkv Immediate wait for ablation")
+            wait_async_handles(reqs)
+        # apply dq_delta, dk_delta and dv_delta from remote
+        if is_update_dq:
+            dq += dq_delta[buffer_idx_1]
+        if is_update_dkv:
+            dk += dk_delta_from_peer
+            dv += dv_delta_from_peer
+
+    if comm_mode == "lightseq":
+        wait_async_handles(reqs)
+        # apply dk_delta and dv_delta to sender
+        if is_update_last_dkv:
+            dk += dk_delta[buffer_idx_2]
+            dv += dv_delta[buffer_idx_2]
+
+    dq, dk, dv = [
+        rearrange(_x, "(b s) h d -> b h s d", s=max_seqlen) for _x in [dq, dk, dv]
+    ]
+    return dq, dk, dv
+
+
+class _attention_varlen(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, causal, sm_scale):
+        try:
+            global args
+            comm_mode = args.comm_mode
+            backward_engine = args.backward_engine
+        except (NameError, AttributeError):
+            # fall back to defaults when the benchmarking CLI args are absent
+            comm_mode = "lightseq"
+            backward_engine = "flash"
+
+        q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen(
+            q, k, v, causal, sm_scale, comm_mode
+        )
+
+        ctx.save_for_backward(q, k, v, o, L, cu_seq_lens)
+        ctx.max_seqlen = max_seqlen
+        ctx.sm_scale = sm_scale
+        ctx.comm_mode = comm_mode
+        ctx.backward_engine = backward_engine
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, L, cu_seq_lens = ctx.saved_tensors
+        sm_scale = ctx.sm_scale
+        max_seqlen = ctx.max_seqlen
+
+        dq, dk, dv = _lightseq_backward_varlen(
+            do,
+            q,
+            k,
+            v,
+            o,
+            L,
+            sm_scale,
+            ctx.comm_mode,
+            ctx.backward_engine,
+            cu_seq_lens,
+            max_seqlen,
+        )
+        return dq, dk, dv, None, None
+
+
+dist_attn_varlen = _attention_varlen.apply
+
+
+# @pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)])
+def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16):
+    torch.manual_seed(20)
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    PAD = world_size * 256
+    
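# trim PAD tokens so the test runs on an "unpadded" length (presumably to
+    # exercise the varlen handling rather than a neat power-of-two size)
+    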
seq_per_rank = (N_CTX - PAD) // world_size + q = ( + torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + + # DEBUG: mask out + # mask = torch.zeros(Z, H, seq_per_rank * (world_size - 1), D_HEAD).cuda() + # mask_2 = torch.ones(Z, H, seq_per_rank, D_HEAD).cuda() + # mask = torch.cat((mask, mask_2), dim=-2).to(dtype) + # q = mask * q + # k = mask * k + # v = mask * v + + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX - PAD, N_CTX - PAD), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + + a, b, c, d = q.size() + real_q = ( + q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_k = ( + k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_v = ( + v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_do = ( + dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + + tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose( + ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_out, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose( + ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dq, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dq" # {ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dq} {torch.max(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dq)} rank {rank} fails backward dk" + assert torch.allclose( + ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dk, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dk" # {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose( + ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dv, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dv" # {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + + +# TODO(High Priority): Investigate why rank 0 tends to 
have larger numerical difference. +def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): + torch.manual_seed(177) + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + seq_per_rank = N_CTX // world_size + + sm_scale = 0.5 + dout = torch.randn_like(q) + # torch reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) + ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) + # print(q.shape, ref_k.shape, k.shape) + p = torch.matmul(q, ref_k.transpose(2, 3)) * sm_scale + assert causal + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + ref_out = torch.matmul(p, ref_v) + ref_out.backward(dout) + ref_dv, v.grad = ref_v.grad.clone(), None + # print("Before reduce", ref_dv.shape) + ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1, 2))).transpose(1, 2) + # print("After reduce", ref_dv.shape) + ref_dk, k.grad = ref_k.grad.clone(), None + ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1, 2))).transpose(1, 2) + ref_dq, q.grad = q.grad.clone(), None + + # flash reference + from flash_attn import flash_attn_func, flash_attn_qkvpacked_func + + flash_q = q.transpose(1, 2).clone().detach().requires_grad_(True) + flash_k = k.transpose(1, 2).clone().detach().requires_grad_(True) + flash_v = v.transpose(1, 2).clone().detach().requires_grad_(True) + flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) + flash_ref_out.backward(dout.transpose(1, 2)) + flash_ref_out = flash_ref_out.transpose(1, 2) + flash_ref_dv, v.grad = flash_v.grad.clone(), None + flash_ref_dv = flash_ref_dv.transpose(1, 2) + flash_ref_dk, k.grad = flash_k.grad.clone(), None + flash_ref_dk = flash_ref_dk.transpose(1, 2) + flash_ref_dq, q.grad = flash_q.grad.clone(), None + flash_ref_dq = flash_ref_dq.transpose(1, 2) + + # triton implementation + + a, b, c, d = q.size() + real_q = ( + q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_k = ( + k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, KVH, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_v = ( + v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, KVH, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + real_do = ( + dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] + .view(a, b, -1, d) + .contiguous() + .clone() + .detach() + .requires_grad_(True) + ) + + tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() + + # compare + assert torch.allclose( + flash_ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_out, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails forward against flash" + print(f" *** rank {rank} passes forward") + tri_out.backward(real_do) + tri_dv, real_v.grad = real_v.grad.clone(), None + tri_dk, real_k.grad = real_k.grad.clone(), None + tri_dq, real_q.grad = real_q.grad.clone(), None + assert torch.allclose( + 
flash_ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dq, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails backward dq against flash" + # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose( + flash_ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dk, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose( + flash_ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dv, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward against flash") + + assert torch.allclose( + ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_out, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails forward" + print(f" *** rank {rank} passes forward") + assert torch.allclose( + ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dq, + atol=1e-2, + rtol=0, + ), f" rank {rank} fails backward dq" + # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) + assert torch.allclose( + ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dk, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" + assert torch.allclose( + ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], + tri_dv, + atol=1e-2, + rtol=0, + ), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" + print(f"rank {rank} passes backward") + + +# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +try: + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as flash_attn_func, + ) + + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None +HAS_FLASH = None +ONLY_FLASH = False + +# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[ + 2**i for i in range(18, 19) + ], # [ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], + line_arg="provider", + line_vals=["triton"] + if not ONLY_FLASH + else [] + (["flash"] if HAS_FLASH else []), + line_names=["Triton"] + if not ONLY_FLASH + else [] + ([f"Flash-{FLASH_VER}"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.bfloat16, + "mode": mode, + "causal": causal, + }, + ) + for mode in ["all"] + for 
causal in [True] +] + + +# @triton.testing.perf_report(configs) +def bench_flash_attention( + BATCH, + H, + KVH, + N_CTX, + D_HEAD, + causal, + mode, + provider, + args, + dtype=torch.bfloat16, + device="cuda", +): + assert mode == "all" # mode in ['fwd', 'bwd'] + n_warmup = 10 + n_repeat = 10 + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + seq_rank = get_sequence_parallel_rank() + seq_world_size = get_sequence_parallel_size() + if provider == "triton": + q = torch.randn( + (BATCH, H, N_CTX // seq_world_size, D_HEAD), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + k = torch.randn( + (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + v = torch.randn( + (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + if seq_rank == 0: + print(f"Benchmarking per GPU qkv shape: {q.shape}") + sm_scale = 1.3 + fwd_fn = lambda: dist_attn_varlen(q, k, v, causal, sm_scale) + if provider == "flash": + qkv = torch.randn( + (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True + ) + if FLASH_VER == 1: + lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) + fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0.0, N_CTX, causal=causal) + elif FLASH_VER == 2: + fwd_fn = lambda: flash_attn_func(qkv, causal=causal) + else: + raise ValueError(f"unknown {FLASH_VER = }") + + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size + attn_flops = 2 * flops_per_matmul + + assert causal + if causal: + attn_flops *= 0.5 + fwd_flops = attn_flops + bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) + + o = fwd_fn() + do = torch.randn_like(o) + bwd_fn = lambda: o.backward(do, retain_graph=True) + + def run_benchmark(fn): + time_list = [] + for _ in tqdm(range(n_warmup)): + cache.zero_() + fn() + torch.cuda.synchronize() + if args.debug: + print_and_reset_comm_stats() + for i in tqdm(range(n_repeat)): + cache.zero_() + torch.cuda.synchronize() + time_s = time.time() + fn() + torch.cuda.synchronize() + time_e = time.time() + time_list.append((time_e - time_s) * 1000.0) + if args.debug: + print_and_reset_comm_stats() + return np.asarray(time_list) + + fwd_time_arr = run_benchmark(fwd_fn) + bwd_time_arr = run_benchmark(bwd_fn) + + fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 + print( + f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n" + ) + + bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 + print( + f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n" + ) + + # total + total_time_arr = fwd_time_arr + bwd_time_arr + total_flops = fwd_flops + bwd_flops + total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 + print( + f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n" + ) + + # return total_flops_ps + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--comm-mode", type=str, default="lightseq") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--run-mode", type=str, default="test") + parser.add_argument("--bs", type=int, default=1) + parser.add_argument("--n_heads", type=int, default=32) + 
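# note: --start_ctx/--end_ctx are log2 exponents; benchmark mode sweeps
+    # N_CTX over 2**start_ctx .. 2**(end_ctx - 1)
+    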
parser.add_argument("--n_kvheads", type=int, default=32)
+    parser.add_argument("--d_head", type=int, default=128)
+    parser.add_argument("--start_ctx", type=int, default=12)
+    parser.add_argument("--end_ctx", type=int, default=18)
+    parser.add_argument("--forward_engine", type=str, default="triton")
+    parser.add_argument("--backward_engine", type=str, default="flash")
+
+    global args
+    args = parser.parse_args()
+    initialize_distributed()
+
+    assert args.forward_engine == "triton", "Only triton forward is implemented."
+    assert args.backward_engine in [
+        "flash",
+        "xformers",
+    ], "Only flash or xformers backward is implemented."
+
+    if args.backward_engine == "flash":
+        from flash_attn.flash_attn_interface import (
+            _flash_attn_backward,
+            _flash_attn_forward,
+        )
+    else:
+        try:
+            import xformers.ops
+            from xformers.ops.fmha import (
+                _memory_efficient_attention_backward,
+                cutlass,
+                flash,
+            )
+            from xformers.ops.fmha.common import Context, Inputs
+        except ImportError:
+            print("xformers not found! Please install it before trying to use it.")
+
+    if args.run_mode == "benchmark":
+        for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]:
+            bench_flash_attention(
+                args.bs,
+                args.n_heads,
+                args.n_kvheads,
+                N_CTX,
+                args.d_head,
+                True,
+                "all",
+                "triton",
+                args,
+            )  # .run(save_path='.', print_data=True)
+            reset_global_memory_buffer()
+    else:
+        assert args.run_mode == "test"
+        for N_CTX in [4096]:
+            test_op(2, 16, N_CTX, 128, True)
+            # test_gqa(1, 16, 8, N_CTX, 128, True)
+            reset_global_memory_buffer()
diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
new file mode 100644
index 000000000..35e680774
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
@@ -0,0 +1,755 @@
+"""
+Materialization-aware gradient checkpointing monkey patch.
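+
+How it works (a summary of the mechanism implemented below): each decoder
+layer's flash-attention output is stashed in a process-global buffer during
+the forward pass, so the checkpointed backward can stop recomputation right
+before the attention call instead of re-running it; gradients of the residual
+stream are passed across checkpoint boundaries through a small local buffer
+and tensor hooks.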
+""" +from typing import List, Optional, Tuple + +import torch +import transformers +from einops import rearrange +from torch import nn +from torch.utils.checkpoint import ( + _get_autocast_kwargs, + check_backward_validity, + detach_variable, + get_device_states, + set_device_states, +) +from transformers.models.llama.modeling_llama import ( + BaseModelOutputWithPast, + apply_rotary_pos_emb, +) + +from .async_communication import initialize_distributed, reset_global_memory_buffer +from .lightseq_async_attn import _lightseq_backward, _lightseq_forward + +# define a global buffer to save flash attention outputs +# it's called global because it saves the outputs for all layers +global_flash_attn_out_buffer = None + +# define a local buffer to save recomputed qkv +# it's called local because it's a temporary buffer which will be updated across layers +local_res_grad_buffer = None + +# hooks for the gradients of residual +global_hooks = [] + + +def init_flash_attn_buffers(num_layers): + # update the global buffer according to number of layers + global global_flash_attn_out_buffer + global_flash_attn_out_buffer = [None] * num_layers + + +def clean_hook(): + # Remove all hooks in the global buffer + for hook in global_hooks: + hook.remove() + # Clear the global buffer + global_hooks.clear() + + +def clear_all_buffers_at_the_end_of_training(): + # call it at the end of training + global lobal_flash_attn_out_buffer + global_flash_attn_out_buffer = None + global local_res_grad_buffer + local_res_grad_buffer = None + clean_hook() + + +def save_flash_attn_out_to_global_buffer(idx, out): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx] = out + + +def get_flash_attn_out_from_global_buffer(idx): + global global_flash_attn_out_buffer + return global_flash_attn_out_buffer[idx] + + +def free_flash_attn_out_buffer(idx): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx] = None + + +def write_gradient_to_flash_attn_out(idx, grad): + global global_flash_attn_out_buffer + global_flash_attn_out_buffer[idx].grad = grad + + +def save_res_grad_hook(grad): + global local_res_grad_buffer + local_res_grad_buffer = grad + + +def load_and_add_res_grad_hook(grad): + grad += get_res_grad_from_local_buffer() + + +def get_res_grad_from_local_buffer(): + global local_res_grad_buffer + assert local_res_grad_buffer is not None + return local_res_grad_buffer + + +class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function): + """Avoid doing twice flash attention forward during checkpointed backward. + args: + hidden_states, # i.e., flash attention output which is saved in global buffer. + attention_mask, + position_ids, + residual, # the gradient of residual is saved in local buffer to pass across ckpt layers. + """ + + @staticmethod + def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.layer_idx = layer_idx + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) 
+ ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if i == 0 and ctx.layer_idx != 0: + # flash attention output is saved to the global buffer during forward + ctx.inputs.append(None) + else: + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + with torch.no_grad(): + q, k, v, residual = run_function(*args) + softmax_scale = q.shape[-1] ** (-0.5) + + # lightseq version + _, _, _, out, softmax_lse = _lightseq_forward( + q, k, v, True, softmax_scale, comm_mode="lightseq" + ) + rng_state = None + + # save flash attention output to global buffer + save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) + tensor_inputs += [softmax_lse] + ctx.softmax_scale = softmax_scale + + ctx.save_for_backward(*tensor_inputs) + + return out, residual + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter" + " is passed to .backward(). Please use .backward() and do not pass its `inputs`" + " argument." + ) + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + tensors, softmax_lse = tensors[:-1], tensors[-1] + + # Fill in inputs with appropriate saved tensors. + # Fill the flash attention output first + if ctx.layer_idx > 0: + # inputs[0] should be flash attention output + inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx - 1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. 
+ rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(tuple(inputs)) + with torch.enable_grad(), torch.cuda.amp.autocast( + **ctx.gpu_autocast_kwargs + ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): + # Stop recomputation before flash attention + # It is unecessary to run recomputation for flash attn + q, k, v, residual = ctx.run_function(*detached_inputs) + + # run backward() with only tensor that requires grad + # run flash attention backward first: + # get 'dout' from auto_grad inputs + # get 'out' from global buffer + # get 'qkv' from the recomputed tensors + # dq = torch.empty(q.shape, dtype=q.dtype, device=q.device) + # dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) + # dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) + out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) + # todo get dout + dout = args[0] + + # lightseq version + dq, dk, dv = _lightseq_backward( + dout, + q, + k, + v, + out, + softmax_lse, + ctx.softmax_scale, + comm_mode="lightseq", + backward_engine="flash", + ) + # dqkv = torch.stack([dq, dk, dv]) + + # run backward for the part before flash attention + # qkv.backward(dqkv) + torch.autograd.backward([q, k, v], [dq, dk, dv]) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs + ) + + # write flash attention output gradients to buffer + if ctx.layer_idx > 0: + write_gradient_to_flash_attn_out(ctx.layer_idx - 1, detached_inputs[0].grad) + + return (None, None, None) + grads + + +def checkpoint_end_with_flash_attention( + function, layer_idx, *args, use_reentrant: bool = True, **kwargs +): + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs and use_reentrant: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + return CheckpointFunctionEndWithFlashAttention.apply( + function, layer_idx, preserve, *args + ) + + +class CheckpointFunctionLastModule(torch.autograd.Function): + """ + for the last ffn layer after flash attention, modifications include: + write the gradients wrt flash attention output and residual to the global buffer. + """ + + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_cuda_in_fwd = False + if torch.cuda._initialized: + ctx.had_cuda_in_fwd = True + ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. 
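+        # Same bookkeeping as CheckpointFunctionEndWithFlashAttention.forward,
+        # except the first argument (the attention output) is read back from the
+        # global buffer in backward rather than being stored here.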
+ ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + + assert torch.is_tensor( + args[0] + ), "assuming the first tensor is the flash attention output" + for i, arg in enumerate(args): + if torch.is_tensor(arg) and i == 0: + # flash attn output has been saved to global buffer + ctx.inputs.append(None) + elif torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter" + " is passed to .backward(). Please use .backward() and do not pass its `inputs`" + " argument." + ) + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + + # Fill in inputs with appropriate saved tensors. + # Fill the flash attention output first + # inputs[0] should be flash attention output + inputs[0] = get_flash_attn_out_from_global_buffer(-1) + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. + rng_devices = [] + if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: + rng_devices = ctx.fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_cuda_in_fwd: + set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) + detached_inputs = detach_variable(tuple(inputs)) + with torch.enable_grad(), torch.cuda.amp.autocast( + **ctx.gpu_autocast_kwargs + ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True," + " this checkpoint() is not necessary" + ) + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs + ) + + # write flash attention output gradients to buffer + write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) + + return (None, None) + grads + + +def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs and use_reentrant: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + return CheckpointFunctionLastModule.apply(function, preserve, *args) + + +def llama_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + compute_attn_only: Optional[bool] = False, + compute_ffn_only: 
Optional[bool] = False, + residual: Optional[bool] = None, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + assert compute_ffn_only or compute_attn_only + + if compute_attn_only: + residual = hidden_states + + if residual.requires_grad: + # register a hook to add the gradient of residual + # from next checkpoint layer when doing recomputation + hook = residual.register_hook(load_and_add_res_grad_hook) + global_hooks.append(hook) + + hidden_states = self.input_layernorm(hidden_states) + + # Flash Attention + bsz, q_len, _ = hidden_states.size() + try: + query_states = ( + self.self_attn.q_proj(hidden_states) + .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.self_attn.k_proj(hidden_states) + .view( + bsz, + q_len, + self.self_attn.num_key_value_heads, + self.self_attn.head_dim, + ) + .transpose(1, 2) + ) + value_states = ( + self.self_attn.v_proj(hidden_states) + .view( + bsz, + q_len, + self.self_attn.num_key_value_heads, + self.self_attn.head_dim, + ) + .transpose(1, 2) + ) + except: + # old transformers versions don't support num_key_value_heads + query_states = ( + self.self_attn.q_proj(hidden_states) + .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.self_attn.k_proj(hidden_states) + .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.self_attn.v_proj(hidden_states) + .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, "past_key_value is not supported" + + cos, sin = self.self_attn.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + return ( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + residual, + ) + + elif compute_ffn_only: + hidden_states = self.self_attn.o_proj( + rearrange(hidden_states, "b h s d -> b s (h d)") + ) + # Need to add residual here to make sure checkpoint is right after attention + if residual.requires_grad: + # save the gradient of residual to the local buffer + # collect the hooks which should be removed after backward to avoid memory leak + hook = residual.register_hook(save_res_grad_hook) + global_hooks.append(hook) + + hidden_states = residual + hidden_states + + # Fully Connected + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = 
residual + hidden_states + + outputs = (hidden_states,) + + else: + raise AttributeError + + return outputs + + +def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, +): + assert cache_position is None, "cache_position is not supported" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + attention_mask = None + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + try: + logger.warning_once("***** Using fast gradient checkpointing... *****") + except: + pass + # initialize the global buffer + init_flash_attn_buffers(len(self.layers)) + + if use_cache: + try: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    # apply flash-attention friendly gradient checkpointing: the loop runs
+    # len(self.layers) + 1 times, pairing layer idx - 1's FFN half with layer
+    # idx's attention half, so every checkpoint boundary falls right after an
+    # attention block
+    if self.gradient_checkpointing and self.training:
+        for idx in range(len(self.layers) + 1):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+
+            def forward_first_attn_module(module):
+                def custom_forward(*inputs):
+                    hidden_states, attention_mask, position_ids, _ = inputs
+                    # past_key_value comes from the enclosing loop iteration
+                    return module(
+                        hidden_states,
+                        attention_mask,
+                        position_ids,
+                        past_key_value,
+                        output_attentions,
+                        compute_attn_only=True,
+                    )
+
+                return custom_forward
+
+            def forward_ffn_attn_layer(module1, module2):
+                def custom_forward(*inputs):
+                    hidden_states, attention_mask, position_ids, residual = inputs
+                    # past_key_value comes from the enclosing loop iteration
+                    layer_outputs = module1(
+                        hidden_states,
+                        attention_mask,
+                        position_ids,
+                        past_key_value,
+                        output_attentions,
+                        compute_ffn_only=True,
+                        residual=residual,
+                    )
+                    hidden_states = layer_outputs[0]
+                    return module2(
+                        hidden_states,
+                        attention_mask,
+                        position_ids,
+                        past_key_value,
+                        output_attentions,
+                        compute_attn_only=True,
+                    )
+
+                return custom_forward
+
+            def forward_last_ffn_module(module):
+                def custom_forward(*inputs):
+                    hidden_states, attention_mask, position_ids, residual = inputs
+                    # past_key_value comes from the enclosing loop iteration
+                    return module(
+                        hidden_states,
+                        attention_mask,
+                        position_ids,
+                        past_key_value,
+                        output_attentions,
+                        compute_ffn_only=True,
+                        residual=residual,
+                    )
+
+                return custom_forward
+
+            if idx == 0:
+                layer_outputs = checkpoint_end_with_flash_attention(
+                    forward_first_attn_module(self.layers[0]),
+                    idx,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    None,
+                )
+                hidden_states, residual = layer_outputs[0], layer_outputs[-1]
+            elif idx == len(self.layers):
+                layer_outputs = checkpoint_last_module(
+                    forward_last_ffn_module(self.layers[-1]),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    residual,
+                )
+                hidden_states = layer_outputs[0]
+            else:
+                layer_outputs = checkpoint_end_with_flash_attention(
+                    forward_ffn_attn_layer(self.layers[idx - 1], self.layers[idx]),
+                    idx,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    residual,
+                )
+                hidden_states, residual = layer_outputs[0], layer_outputs[-1]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+    else:
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(
+            v
+            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+            if v is not None
+        )
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
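+
+# Checkpoint-unit layout (illustrative, assuming a 3-layer model):
+#   unit 0: attn(layer 0)
+#   unit 1: ffn(layer 0) + attn(layer 1)
+#   unit 2: ffn(layer 1) + attn(layer 2)
+#   unit 3: ffn(layer 2)
+# Ending each unit with attention lets LightSeq overlap recomputation with the
+# asynchronous ring communication during the backward pass.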
+
+
+def apply_dist_flash_attn_monkey_patch_llama():
+    initialize_distributed()
+    transformers.models.llama.modeling_llama.LlamaModel.forward = forward
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
+        llama_layer_forward
+    )
diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py b/src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py
new file mode 100644
index 000000000..64245a143
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py
@@ -0,0 +1,34 @@
+def extract_local(value, rank, world_size, device, dim=1):
+    # contiguous sharding: rank r keeps the r-th of world_size equal chunks
+    value_local = value.chunk(world_size, dim=dim)[rank]
+    return value_local.to(device)
+
+
+def prepare_dist_flash_attn_inputs(
+    input_ids, position_ids, target_ids, rank, world_size, device
+):
+    local_input_ids = extract_local(
+        input_ids,
+        rank,
+        world_size,
+        device,
+    )
+    local_position_ids = extract_local(
+        position_ids,
+        rank,
+        world_size,
+        device,
+    )
+    if target_ids is not None:
+        local_target_ids = extract_local(
+            target_ids,
+            rank,
+            world_size,
+            device,
+        )
+    else:
+        local_target_ids = None
+    return {
+        "local_input_ids": local_input_ids,
+        "local_position_ids": local_position_ids,
+        "local_target_ids": local_target_ids,
+    }
diff --git a/src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py
new file mode 100644
index 000000000..8221f6f91
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py
@@ -0,0 +1,114 @@
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+import transformers
+
+try:
+    from yunchang.ulysses import UlyssesAttention
+
+    ulysses_attn = UlyssesAttention()
+except ImportError:
+    print(
+        "If you want to use the UlyssesAttention class, please install the yunchang package."
+    )
+    ulysses_attn = None
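+
+# Ulysses-style sequence parallelism (as implemented by yunchang): each rank
+# holds a contiguous 1/world_size slice of the sequence; inside attention an
+# all-to-all trades the sequence shard for a head shard, so each rank attends
+# over the full sequence for num_heads / world_size heads, and an inverse
+# all-to-all restores the sequence layout afterwards. See the yunchang package
+# for the actual kernels.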
+
+
+def new_flash_attn_forward(
+    self,
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    query_length,
+    dropout=0.0,
+    softmax_scale=None,
+    use_sliding_windows=False,
+):
+    if not self._flash_attn_uses_top_left_mask:
+        causal = self.is_causal
+    else:
+        causal = self.is_causal and query_length != 1
+
+    # padded inputs and sliding windows are not supported by this path
+    assert attention_mask is None
+    assert causal is True
+    assert use_sliding_windows is False
+    attn_output = ulysses_attn(
+        query_states,
+        key_states,
+        value_states,
+        dropout,
+        softmax_scale,
+        causal=causal,
+    )
+
+    return attn_output
+
+
+def new_decoder_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    assert isinstance(
+        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
+    ) or isinstance(
+        self.self_attn,
+        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
+    ), "Please toggle on the Flash Attention 2 implementation when using the Ulysses attention monkey patch."
+
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+        )
+
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        **kwargs,
+    )
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
+
+
+def apply_ulysses_attn_monkey_patch_llama():
+    transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
+        new_flash_attn_forward
+    )
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
+        new_decoder_forward
+    )
diff --git a/src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py b/src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py
new file mode 100644
index 000000000..b05716d0c
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py
@@ -0,0 +1,44 @@
+import torch
+
+
+def extract_local(value, rank, world_size, device, dim=1):
+    dimension_size = value.shape[dim]
+    assert (
+        dimension_size % world_size == 0
+    ), "the sharded dimension must be divisible by world_size"
+    sub_seq_length = dimension_size // world_size
+
+    sub_seq_start = rank * sub_seq_length
+    # honor `dim` instead of hard-coding dimension 1
+    local_value = value.narrow(dim, sub_seq_start, sub_seq_length)
+
+    return local_value.to(device)
+
+
+def prepare_ulysses_attn_inputs(
+    input_ids, position_ids, target_ids, rank, world_size, device
+):
+    local_input_ids = extract_local(
+        input_ids,
+        rank,
+        world_size,
+        device,
+    )
+    local_position_ids = extract_local(
+        position_ids,
+        rank,
+        world_size,
+        device,
+    )
+
+    if target_ids is not None:
+        local_target_ids = extract_local(
+            target_ids,
+            rank,
+            world_size,
+            device,
+        )
+    else:
+        local_target_ids = None
+    return {
+        "local_input_ids": local_input_ids,
+        "local_position_ids": local_position_ids,
+        "local_target_ids": local_target_ids,
+    }
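+
+# Illustrative sharding (assuming world_size=4 and seq_len=8, tokens t0..t7):
+# rank 0 gets [t0 t1], rank 1 gets [t2 t3], rank 2 gets [t4 t5], rank 3 gets
+# [t6 t7] -- contiguous slices, unlike the zigzag layout used by ring attention.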
diff --git a/src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py b/src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py
new file mode 100644
index 000000000..5d5d6a5a7
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py
@@ -0,0 +1,95 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+
+import torch
+import transformers
+
+
+class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
+    """
+    Saves VRAM by offloading checkpointed activations to CPU RAM.
+    Tiny hit to performance, since we mask the movement via non-blocking calls.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, dY):
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        torch.autograd.backward(output, dY)
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
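+
+# Illustrative usage (hypothetical call site, shapes omitted): wrap a module's
+# forward so its input activation lives in CPU RAM between forward and backward:
+#
+#     output = Unsloth_Offloaded_Gradient_Checkpointer.apply(
+#         decoder_layer.__call__, hidden_states, attention_mask, position_ids
+#     )
+#
+# The wrapped function must return a 1-tuple, matching the
+# `(output,) = ctx.forward_function(...)` unpacking in backward().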
+
+
+def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+    assert (
+        gradient_checkpointing_kwargs is None
+    ), "gradient_checkpointing_kwargs is not supported"
+    if not self.supports_gradient_checkpointing:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support gradient checkpointing."
+        )
+
+    gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
+    # The old GC format (transformers < 4.35.0, models on the Hub with an
+    # overridden `_set_gradient_checkpointing`) is not supported here and
+    # raises NotImplementedError below.
+    _is_using_old_format = (
+        "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+    )
+
+    if not _is_using_old_format:
+        self._set_gradient_checkpointing(
+            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
+        )
+    else:
+        raise NotImplementedError()
+
+    if getattr(self, "_hf_peft_config_loaded", False):
+        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+        # When training with PEFT, only LoRA layers have requires_grad set to True, but the output of frozen
+        # layers needs to propagate the gradients to make sure the gradient flows.
+        self.enable_input_require_grads()
+
+
+def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
+    transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
+        new_gradient_checkpointing_enable
+    )
diff --git a/src/axolotl/integrations/easy_context/usp/monkey_patch.py b/src/axolotl/integrations/easy_context/usp/monkey_patch.py
new file mode 100644
index 000000000..9745036f3
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/usp/monkey_patch.py
@@ -0,0 +1,114 @@
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+import transformers
+
+try:
+    from yunchang import LongContextAttention
+
+    usp_attn = LongContextAttention(ring_impl_type="zigzag")
+except ImportError:
+    print(
+        "If you want to use the LongContextAttention class, please install the yunchang package."
+    )
+    usp_attn = None
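+
+# USP ("A Unified Sequence Parallelism Approach for Long Context Generative AI",
+# https://arxiv.org/abs/2405.07719) composes the two schemes: ranks form a
+# ulysses_degree x ring_degree grid, with head-parallel all-to-alls inside each
+# Ulysses group and zigzag ring attention across groups (hence
+# ring_impl_type="zigzag").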
+
+
+def new_flash_attn_forward(
+    self,
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    query_length,
+    dropout=0.0,
+    softmax_scale=None,
+    use_sliding_windows=False,
+):
+    if not self._flash_attn_uses_top_left_mask:
+        causal = self.is_causal
+    else:
+        causal = self.is_causal and query_length != 1
+
+    # padded inputs and sliding windows are not supported by this path
+    assert attention_mask is None
+    assert causal is True
+    assert use_sliding_windows is False
+    attn_output = usp_attn(
+        query_states,
+        key_states,
+        value_states,
+        dropout,
+        softmax_scale,
+        causal=causal,
+    )
+
+    return attn_output
+
+
+def new_decoder_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    assert isinstance(
+        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
+    ) or isinstance(
+        self.self_attn,
+        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
+    ), "Please toggle on the Flash Attention 2 implementation when using the USP attention monkey patch."
+
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+        )
+
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        **kwargs,
+    )
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
+
+
+def apply_usp_attn_monkey_patch_llama():
+    transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
+        new_flash_attn_forward
+    )
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
+        new_decoder_forward
+    )
diff --git a/src/axolotl/integrations/easy_context/usp/prepare_inputs.py b/src/axolotl/integrations/easy_context/usp/prepare_inputs.py
new file mode 100644
index 000000000..c2c4f2fad
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/usp/prepare_inputs.py
@@ -0,0 +1,58 @@
+import torch
+from yunchang import set_seq_parallel_pg
+from yunchang.comm import zigzag_extract_local
+
+
+def prepare_usp_attn_inputs(
+    input_ids,
+    position_ids,
+    target_ids,
+    rank,
+    world_size,
+    device,
+    ulysses_degree,
+    ring_degree,
+):
+    """
+    Prepare inputs for USP attention.
+
+    Note: `ulysses_degree` precedes `ring_degree`, matching the positional
+    call in `easy_context.prepare_seq_parallel_inputs`.
+
+    USP: A Unified Sequence Parallelism Approach for Long Context Generative AI
+    https://arxiv.org/abs/2405.07719
+    """
+
+    set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)
+
+    local_input_ids = zigzag_extract_local(
+        input_ids,
+        rank,
+        world_size,
+        ring_degree,
+        ulysses_degree,
+    ).to(device)
+
+    # truncate position_ids to the same size as input_ids
+    position_ids = position_ids[:, : local_input_ids.shape[1]]
+
+    local_position_ids = zigzag_extract_local(
+        position_ids,
+        rank,
+        world_size,
+        ring_degree,
+        ulysses_degree,
+    ).to(device)
+
+    if target_ids is not None:
+        local_target_ids = zigzag_extract_local(
+            target_ids,
+            rank,
+            world_size,
+            ring_degree,
+            ulysses_degree,
+        ).to(device)
+    else:
+        local_target_ids = None
+    return {
+        "local_input_ids": local_input_ids,
+        "local_position_ids": local_position_ids,
+        "local_target_ids": local_target_ids,
+    }
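+
+# Illustrative configuration (assumed values): with world_size=8,
+# ulysses_degree=4 and ring_degree=2, set_seq_parallel_pg() builds a 4x2
+# process grid; zigzag_extract_local() then gives each rank a zigzag-interleaved
+# 1/8 shard of the sequence.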
diff --git a/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py
new file mode 100644
index 000000000..67dd05631
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py
@@ -0,0 +1,114 @@
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+import transformers
+from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
+
+
+def new_flash_attn_forward(
+    self,
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    query_length,
+    dropout=0.0,
+    softmax_scale=None,
+    use_sliding_windows=False,
+):
+    if not self._flash_attn_uses_top_left_mask:
+        causal = self.is_causal
+    else:
+        causal = self.is_causal and query_length != 1
+
+    # padded inputs and sliding windows are not supported by this path
+    assert attention_mask is None
+    assert causal is True
+    assert use_sliding_windows is False
+    attn_output = zigzag_ring_flash_attn_func(
+        query_states,
+        key_states,
+        value_states,
+        dropout,
+        softmax_scale,
+        causal=causal,
+    )
+
+    return attn_output
+
+
+def new_decoder_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    assert isinstance(
+        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
+    ) or isinstance(
+        self.self_attn,
+        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
+    ), "Please toggle on the Flash Attention 2 implementation when using the zigzag ring attention monkey patch."
+
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+        )
+
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        **kwargs,
+    )
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.mlp(hidden_states)
+    hidden_states = residual + hidden_states
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
+
+
+def apply_zigzag_ring_attn_monkey_patch_llama():
+    transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
+        new_flash_attn_forward
+    )
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
+        new_decoder_forward
+    )
+
+
+def apply_zigzag_ring_attn_monkey_patch_mistral():
+    transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
+        new_flash_attn_forward
+    )
+    transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
+        new_decoder_forward
+    )
diff --git a/src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py b/src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py
new file mode 100644
index 000000000..f5cea9064
--- /dev/null
+++ b/src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py
@@ -0,0 +1,40 @@
+import torch
+
+
+def extract_local(value, rank, world_size, device, dim=1):
+    value_chunks = value.chunk(2 * world_size, dim=dim)
+    local_value = torch.cat(
+        [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
+    )
+    return local_value.to(device)
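+
+# Illustrative layout (assuming world_size=2 and 8 tokens t0..t7): the sequence
+# is cut into 2 * world_size = 4 chunks; rank 0 gets [t0 t1 | t6 t7] and rank 1
+# gets [t2 t3 | t4 t5]. Pairing chunk i with chunk 2 * world_size - 1 - i
+# balances the causal-attention workload across ranks.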
+
+
+def prepare_zigzag_ring_attn_inputs(
+    input_ids, position_ids, target_ids, rank, world_size, device
+):
+    local_input_ids = extract_local(
+        input_ids,
+        rank,
+        world_size,
+        device,
+    )
+    local_position_ids = extract_local(
+        position_ids,
+        rank,
+        world_size,
+        device,
+    )
+    if target_ids is not None:
+        local_target_ids = extract_local(
+            target_ids,
+            rank,
+            world_size,
+            device,
+        )
+    else:
+        local_target_ids = None
+    return {
+        "local_input_ids": local_input_ids,
+        "local_position_ids": local_position_ids,
+        "local_target_ids": local_target_ids,
+    }
diff --git a/src/axolotl/monkeypatch/sequence_parallel.py b/src/axolotl/monkeypatch/sequence_parallel.py
new file mode 100644
index 000000000..b3caefb14
--- /dev/null
+++ b/src/axolotl/monkeypatch/sequence_parallel.py
@@ -0,0 +1,123 @@
+"""
+Utilities for sequence parallelism.
+
+Modified from:
+https://github.com/Qihoo360/360-LLaMA-Factory/blob/f295a5760cceebe069fb5b975813d2c945598acb/src/llamafactory/model/model_utils/sequence_parallel.py
+"""
+
+from functools import partial
+
+import torch.distributed as dist
+import transformers
+import transformers.modeling_attn_mask_utils
+from ring_flash_attn import (
+    ring_flash_attn_func,
+    stripe_flash_attn_func,
+    zigzag_ring_flash_attn_func,
+)
+
+
+# The wrappers below accept `attention_mask`, `q_len` and `sliding_window` only
+# for signature compatibility with transformers' `_flash_attention_forward`;
+# the ring kernels ignore them.
+def ring_flash_attn_forward(
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    q_len,
+    dropout=0,
+    sliding_window=None,
+    is_causal=True,
+    group=None,
+    **kwargs,
+):
+    attn_output = ring_flash_attn_func(
+        query_states, key_states, value_states, dropout, causal=is_causal, group=group
+    )
+
+    return attn_output
+
+
+def zigzag_flash_attn_forward(
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    q_len,
+    dropout=0,
+    sliding_window=None,
+    is_causal=True,
+    group=None,
+    **kwargs,
+):
+    attn_output = zigzag_ring_flash_attn_func(
+        query_states, key_states, value_states, dropout, causal=is_causal, group=group
+    )
+
+    return attn_output
+
+
+def stripe_flash_attn_forward(
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    q_len,
+    dropout=0,
+    sliding_window=None,
+    is_causal=True,
+    group=None,
+    **kwargs,
+):
+    attn_output = stripe_flash_attn_func(
+        query_states, key_states, value_states, dropout, causal=is_causal, group=group
+    )
+
+    return attn_output
+
+
+def init_sp_group(sp_size):
+    assert dist.is_initialized()
+    world_size = dist.get_world_size()
+    assert (
+        world_size % sp_size == 0
+    ), "Total number of GPUs must be a multiple of sequence_parallel_size."
+
+    sp_group_num = world_size // sp_size
+    sp_ranks_list = [
+        list(range(i * sp_size, i * sp_size + sp_size)) for i in range(sp_group_num)
+    ]
+
+    # every rank must participate in creating every group; each rank then keeps
+    # only the group it belongs to
+    sp_groups = [dist.new_group(sp_ranks_this) for sp_ranks_this in sp_ranks_list]
+
+    global_rank_this = dist.get_rank()
+    sp_idx = global_rank_this // sp_size
+    return sp_groups[sp_idx]
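+
+# Illustrative grouping (assumed values): with 8 GPUs and
+# sequence_parallel_size=4, init_sp_group() creates rank groups [0, 1, 2, 3]
+# and [4, 5, 6, 7]; ranks 0-3 then attend over one long sequence together
+# while ranks 4-7 handle another.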
+
+
+def apply_sequence_parallel(cfg):
+    if cfg.sequence_parallel_size == 1:
+        return None  # no sequence parallelism
+
+    # initialize sequence-parallel process groups
+    group_this = init_sp_group(cfg.sequence_parallel_size)
+
+    if cfg.sequence_parallel_mode == "ring":
+        new_flash_attention_forward = partial(
+            ring_flash_attn_forward, group=group_this
+        )
+    elif cfg.sequence_parallel_mode == "zigzag-ring":
+        new_flash_attention_forward = partial(
+            zigzag_flash_attn_forward, group=group_this
+        )
+    elif cfg.sequence_parallel_mode == "stripe":
+        new_flash_attention_forward = partial(
+            stripe_flash_attn_forward, group=group_this
+        )
+    else:
+        raise NotImplementedError(
+            f"Unsupported sequence_parallel_mode: {cfg.sequence_parallel_mode!r}; "
+            "expected 'ring', 'zigzag-ring', or 'stripe'."
+        )
+
+    # route transformers' flash-attention entry point through the ring variant
+    transformers.modeling_flash_attention_utils._flash_attention_forward = (
+        new_flash_attention_forward
+    )
+
+    return group_this
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 44f570b88..0682b81b5 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -547,6 +547,20 @@ class ModelLoader:
 
         patch_self_attn_lora(self.cfg)
 
+        if self.cfg.sequence_parallel_size > 1:
+            from axolotl.integrations.easy_context import (
+                apply_seq_parallel_monkey_patch,
+            )
+
+            method = self.cfg.sequence_parallel_mode
+            model_type = self.cfg.model_type
+
+            # patch the attention and decoder-layer forwards for sequence parallelism
+            apply_seq_parallel_monkey_patch(method, model_type)
+
+            # the patched kernels assume FlashAttention-2, so force it on
+            self.cfg.attn_implementation = "flash_attention_2"
+
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention: