From 7efc787ac8b8998cae3901b78c1d42e5a425023f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 5 Mar 2025 14:53:18 +0000 Subject: [PATCH] cleanup --- .../integrations/easy_context/__init__.py | 96 -- .../easy_context/dist_flash_attn/README.md | 11 - .../dist_flash_attn/async_communication.py | 776 ---------- .../dist_flash_attn/lightseq_async_attn.py | 1209 ---------------- .../lightseq_async_attn_varlen.py | 1273 ----------------- .../dist_flash_attn/monkey_patch.py | 754 ---------- .../dist_flash_attn/prepare_input.py | 34 - .../easy_context/ulysses_attn/monkey_patch.py | 114 -- .../ulysses_attn/prepare_inputs.py | 44 - .../monkey_patch.py | 95 -- .../easy_context/usp/monkey_patch.py | 114 -- .../easy_context/usp/prepare_inputs.py | 58 - .../zigzag_ring_attn/monkey_patch.py | 106 -- .../zigzag_ring_attn/prepare_inputs.py | 40 - src/axolotl/utils/models.py | 10 - src/axolotl/utils/trainer.py | 12 - 16 files changed, 4746 deletions(-) delete mode 100644 src/axolotl/integrations/easy_context/__init__.py delete mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/README.md delete mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py delete mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py delete mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py delete mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py delete mode 100644 src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py delete mode 100644 src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py delete mode 100644 src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py delete mode 100644 src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py delete mode 100644 src/axolotl/integrations/easy_context/usp/monkey_patch.py delete mode 100644 src/axolotl/integrations/easy_context/usp/prepare_inputs.py delete mode 100644 src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py delete mode 100644 src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py diff --git a/src/axolotl/integrations/easy_context/__init__.py b/src/axolotl/integrations/easy_context/__init__.py deleted file mode 100644 index 93d38a2e9..000000000 --- a/src/axolotl/integrations/easy_context/__init__.py +++ /dev/null @@ -1,96 +0,0 @@ -import logging - -from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama -from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs -from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama -from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs -from .usp.monkey_patch import apply_usp_attn_monkey_patch_llama -from .usp.prepare_inputs import prepare_usp_attn_inputs -from .zigzag_ring_attn.monkey_patch import ( - apply_zigzag_ring_attn_monkey_patch_llama, - apply_zigzag_ring_attn_monkey_patch_mistral, -) -from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs - -logger = logging.getLogger(__name__) - - -def prepare_seq_parallel_inputs( - seq_algo, - input_ids, - position_ids, - target_ids, - rank, - world_size, - device, - *args, - **kwargs, -): - if seq_algo == "zigzag_ring_attn": - return prepare_zigzag_ring_attn_inputs( - input_ids, position_ids, target_ids, rank, world_size, device - ) - elif seq_algo == "dist_flash_attn": - return prepare_dist_flash_attn_inputs( - input_ids, 
position_ids, target_ids, rank, world_size, device - ) - elif seq_algo == "ulysses_attn": - return prepare_ulysses_attn_inputs( - input_ids, position_ids, target_ids, rank, world_size, device - ) - elif seq_algo == "usp_attn": - ring_degree = kwargs.get("ring_degree", 1) - ulysses_degree = world_size // ring_degree - logger.info( - f"Applying USP: Ring degree: {ring_degree}, Ulysses degree: {ulysses_degree}" - ) - return prepare_usp_attn_inputs( - input_ids, - position_ids, - target_ids, - rank, - world_size, - device, - ulysses_degree, - ring_degree, - ) - elif seq_algo == "data_parallel": - return { - "local_input_ids": input_ids.to(device), - "local_position_ids": position_ids.to(device), - "local_target_ids": target_ids.to(device), - } - else: - raise ValueError(f"Invalid seq_algo: {seq_algo}") - - -def apply_seq_parallel_monkey_patch(seq_algo, model): - assert seq_algo in [ - "zigzag_ring_attn", - "dist_flash_attn", - "ulysses_attn", - "data_parallel", - "usp_attn", - ], f"Invalid seq_algo: {seq_algo}" - assert model in ["llama", "mistral"], f"Invalid model: {model}" - if seq_algo == "data_parallel": - return - elif seq_algo == "zigzag_ring_attn" and model == "llama": - apply_zigzag_ring_attn_monkey_patch_llama() - elif seq_algo == "zigzag_ring_attn" and model == "mistral": - apply_zigzag_ring_attn_monkey_patch_mistral() - elif seq_algo == "dist_flash_attn" and model == "llama": - apply_dist_flash_attn_monkey_patch_llama() - elif seq_algo == "ulysses_attn" and model == "llama": - apply_ulysses_attn_monkey_patch_llama() - elif seq_algo == "usp_attn" and model == "llama": - apply_usp_attn_monkey_patch_llama() - else: - raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") - - -def prepare_dataloader(seq_algo, dataloader, accelerator): - if seq_algo == "data_parallel": - return accelerator.prepare(dataloader) - else: - return dataloader diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/README.md b/src/axolotl/integrations/easy_context/dist_flash_attn/README.md deleted file mode 100644 index 9ddeb332e..000000000 --- a/src/axolotl/integrations/easy_context/dist_flash_attn/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# LightSeq -Taken from https://github.com/RulinShao/LightSeq. All credits to the authors. - -``` -@article{li2023lightseq, - title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANS}, - author={Li, Dacheng and Shao, Rulin and Xie𝑠, Anze and Xing𝑐𝑚, Eric P and Gonzalez𝑏, Joseph E and Stoica𝑏, Ion and Ma𝑢, Xuezhe and Zhang𝑠, Hao}, - journal={arXiv preprint arXiv:2310.03294}, - year={2023} -} -``` diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py b/src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py deleted file mode 100644 index 1ea4aded1..000000000 --- a/src/axolotl/integrations/easy_context/dist_flash_attn/async_communication.py +++ /dev/null @@ -1,776 +0,0 @@ -import math -import os -import threading - -import torch -import torch.distributed as dist -from torch.distributed import P2POp, batch_isend_irecv, irecv, isend - -# Sequence parallel group that the current rank belongs to. -_SEQUENCE_PARALLEL_GROUP = None - -# These values enable us to change the sequence parallel sizes on the fly. 
-_SEQUENCE_PARALLEL_SIZE = None -_SEQUENCE_PARALLEL_RANK = None - -# Global buffer for P2P -_PEER_Q = None -_PEER_K = None -_PEER_V = None -_PEER_M = None -_PEER_L = None -_PEER_O = None -_PEER_Q_BWD = None -_PEER_K_BWD = None -_PEER_V_BWD = None -_PEER_O_BWD = None - -_DELTA_DQ = None -_PEER_L = None -_DELTA_DK = None -_DELTA_DV = None -_DK_DELTA_FROM_PEER = None -_DV_DELTA_FROM_PEER = None -_PEER_DO = None - - -_fwd_send_volume = 0 -_fwd_recv_volume = 0 -_bwd_send_volume = 0 -_bwd_recv_volume = 0 - - -def initialize_distributed(): - if dist.is_initialized(): - if dist.get_rank() == 0: - print( - "torch distributed is already initialized, " - "skipping initialization ...", - flush=True, - ) - else: - if int(os.environ["RANK"]) == 0: - print("Initializing Torch distributed.") - dist.init_process_group(backend="nccl") - local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - global_world_size = dist.get_world_size() - torch.cuda.set_device(dist.get_rank() % local_world_size) - - _initialize_sequence_parallel() - - -# create_nccl_communicators() - - -def _initialize_sequence_parallel(sequence_parallel_size=None): - # Get world size and rank. Ensure some consistencies. - assert ( - sequence_parallel_size is None - ), "Multiple sequence parallel group not implemented." - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - - if sequence_parallel_size is None: - sequence_parallel_size = world_size - else: - assert world_size % sequence_parallel_size == 0 - num_sequence_parallel_groups: int = world_size // sequence_parallel_size - - rank = torch.distributed.get_rank() - - # Build the sequence parallel groups. - global _SEQUENCE_PARALLEL_GROUP - global _SEQUENCE_PARALLEL_RANK - global _SEQUENCE_PARALLEL_SIZE - - assert ( - _SEQUENCE_PARALLEL_GROUP is None - ), "sequence parallel group is already initialized" - for i in range(num_sequence_parallel_groups): - ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) - group = torch.distributed.new_group(ranks) - if rank in ranks: - _SEQUENCE_PARALLEL_GROUP = group - _SEQUENCE_PARALLEL_RANK = ranks.index(rank) - _SEQUENCE_PARALLEL_SIZE = len(ranks) - - if dist.get_rank() == 0: - print("************ Finish sequence pralell group Initialization. 
***********") - # _set_global_memory_buffer() - - -def maybe_get_set_global_memory_buffer(q, k, v, m, l, o): - global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O - if _PEER_Q is None: - try: - if get_sequence_parallel_rank() == 0: - print("Initializing global memoery buffer.") - except: - print("Initializing global memoery buffer.") - _PEER_Q = [torch.empty_like(q) for _ in range(2)] - _PEER_K = [torch.empty_like(k) for _ in range(2)] - _PEER_V = [torch.empty_like(v) for _ in range(2)] - _PEER_M = [torch.empty_like(m) for _ in range(2)] - _PEER_L = [torch.empty_like(l) for _ in range(2)] - _PEER_O = [torch.empty_like(o) for _ in range(2)] - - return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O - - -def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do): - global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO - if _DELTA_DQ is None: - try: - if get_sequence_parallel_rank() == 0: - print("Initializing global memoery buffer for backward.") - except: - print("Initializing global memoery buffer for backward.") - _DELTA_DQ = [torch.empty_like(dq) for _ in range(2)] - _DELTA_DK = [torch.empty_like(dk) for _ in range(2)] - _DELTA_DV = [torch.empty_like(dv) for _ in range(2)] - _PEER_L = [torch.empty_like(L) for _ in range(2)] - - _DK_DELTA_FROM_PEER = torch.empty_like(dk) - _DV_DELTA_FROM_PEER = torch.empty_like(dv) - - # may already be initailized in the forward call. - # current forward and backward needs a transpose in q's format - _PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)] - _PEER_K_BWD = [torch.empty_like(k) for _ in range(2)] - _PEER_V_BWD = [torch.empty_like(v) for _ in range(2)] - _PEER_O_BWD = [torch.empty_like(o) for _ in range(2)] - - _PEER_DO = [torch.empty_like(do) for _ in range(2)] - - return ( - _DELTA_DQ, - _DELTA_DK, - _DELTA_DV, - _DK_DELTA_FROM_PEER, - _DV_DELTA_FROM_PEER, - _PEER_Q_BWD, - _PEER_L, - _PEER_K_BWD, - _PEER_V_BWD, - _PEER_O_BWD, - _PEER_DO, - ) - - -def reset_global_memory_buffer(): - global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO - _PEER_Q = None - _PEER_K = None - _PEER_V = None - _PEER_M = None - _PEER_L = None - _PEER_O = None - - _DELTA_DQ = None - _PEER_L = None - _DELTA_DK = None - _DELTA_DV = None - _DK_DELTA_FROM_PEER = None - _DV_DELTA_FROM_PEER = None - _PEER_DO = None - - -# Pytorch defers the creation of nccl communicators to the first P2P call, -# We manually create them so the first isend does not hang without an irecv. -# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138 -# Only support even number of GPUs. 
-def create_nccl_communicators(): - seq_rank = get_sequence_parallel_rank() - seq_group = get_sequence_parallel_group() - - empty_tensor = torch.empty( - 1, - ).cuda() - empty_tensor_2 = torch.empty( - 1, - ).cuda() - if torch.distributed.get_rank() % 2 == 0: - # sender - op1 = P2POp( - op=isend, - tensor=torch.empty( - 1, - ).cuda(), - peer=seq_rank + 1, - group=seq_group, - ) - op2 = P2POp( - op=irecv, - tensor=torch.empty( - 1, - ).cuda(), - peer=seq_rank + 1, - group=seq_group, - ) - # req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) - dist.batch_isend_irecv([op1, op2]) - else: - # receiver - op1 = P2POp( - op=irecv, - tensor=torch.empty( - 1, - ).cuda(), - peer=seq_rank - 1, - group=seq_group, - ) - op2 = P2POp( - op=isend, - tensor=torch.empty( - 1, - ).cuda(), - peer=seq_rank - 1, - group=seq_group, - ) - # req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) - handles = dist.batch_isend_irecv([op1, op2]) - # req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group) - dist.all_reduce(empty_tensor, group=seq_group) - - -def get_sequence_parallel_group(): - """Get the sequence parallel group the caller rank belongs to.""" - # global _SEQUENCE_PARALLEL_GROUP - assert ( - _SEQUENCE_PARALLEL_GROUP is not None - ), "sequence parallel group is not initialized" - return _SEQUENCE_PARALLEL_GROUP - - -def get_sequence_parallel_rank(): - """Return my rank for the sequence parallel group.""" - global _SEQUENCE_PARALLEL_RANK - if _SEQUENCE_PARALLEL_RANK is not None: - return _SEQUENCE_PARALLEL_RANK - return torch.distributed.get_rank(group=get_sequence_parallel_group()) - - -def get_sequence_parallel_size(): - """Return my rank for the sequence parallel group.""" - global _SEQUENCE_PARALLEL_SIZE - if _SEQUENCE_PARALLEL_SIZE is not None: - return _SEQUENCE_PARALLEL_SIZE - return torch.distributed.get_world_size(group=get_sequence_parallel_group()) - - -def destroy_sequence_parallel(): - """Set the groups to none.""" - global _SEQUENCE_PARALLEL_GROUP - _SEQUENCE_PARALLEL_GROUP = None - - -# whether this is the last time the kernel being called -def is_last_time(time_step): - # e.g. 
on a 8-GPU setup: - # R=0: 0 - # R=1: 1 - # R=2: 2 - # R=3: 3 - # R=4: 4, 5, 6, 7 - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - if seq_rank <= seq_world_size // 2: # no one helps these ranks - rank_finish_time = seq_rank - else: - rank_finish_time = seq_world_size // 2 - return rank_finish_time == time_step - - -# Whether the current time step is computing for local q -def is_compute_for_local_query(time_step): - # R=3,4,5,6,7: Yes - # R=0: 0 - # R=1: 0, 1 - # R=2: 0, 1, 2 - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - if seq_rank >= min(seq_world_size // 2, time_step): - return True - return False - - -# Whether the current time step is idle -def is_idle(time_step): - # 0, 1, 2, 3: 4 - # 4, 5, 6, 7: No - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2: - return True - return False - - -# Whether the current time step needs to synchronize with a remote computed result -def is_sync_from_remote(time_step): - # R=0, 1, 2, 3, 4: No - # R=5: 4 - # R=6: 3, 4 - # R=7: 2, 3, 4 - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - if seq_rank > max(seq_world_size // 2, seq_world_size - time_step): - return True - return False - - -def maybe_send_recv_fwd_qkvo( - q: torch.Tensor, - peer_q: torch.Tensor, - k: torch.Tensor, - peer_k: torch.Tensor, - v: torch.Tensor, - peer_v: torch.Tensor, - o_stats: list, # peer_o_stats: list, - time_step: int, - comm_mode, - debug=False, -) -> torch.Tensor: - seq_group = get_sequence_parallel_group() - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - # Handles for operations that actually need to be wait before going to the next iteration. - # For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler; - all_handles = [] - # KV logic: different than older version, every rank to send/recv its own kv, - # to balance communication. In a balanced communication, every step each rank - # should send/recv 4 tensors in total (kv, or qo). For instance, rank 0 when - # time step > 0, should send its own kv and send/recv qo. In the older version, - # rank 0 does not send its kv, and rely on a later rank to pass it, where the - # later rank has to (1) receive kv, send rank 0's kv and send/recv qo. - # Q (load balancing) logic: semantically, this will be "%" world size, so - # the same send/recv rank as KV. Note: Only support even number of machines. - # O (load balancing) logic: rank 0 sends result to rank 7 at time 1. - # It get delayed for one time step, and thus has different maybe_send/recv_rank. - # Use (time_step + 1) to easily convert to synchornize version. 
- maybe_send_rank = seq_rank + (time_step + 1) - maybe_recv_rank = seq_rank - (time_step + 1) - - if debug: - global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume - _debug_send = _fwd_send_volume - _debug_recv = _fwd_recv_volume - - if maybe_send_rank >= seq_world_size: - # send q, no one needs to do remote computation in the last time step - if time_step < (seq_world_size // 2 - 1): - # print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)") - # q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) - all_handles.append( - P2POp( - op=isend, - tensor=q, - peer=maybe_send_rank % seq_world_size, - group=seq_group, - ) - ) - if debug: - _fwd_send_volume += torch.numel(q) * q.element_size() - else: - # send kv - # print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)") - # kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) - # kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) - all_handles.append( - P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group) - ) - all_handles.append( - P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group) - ) - if debug: - _fwd_send_volume += torch.numel(k) * k.element_size() - _fwd_send_volume += torch.numel(v) * v.element_size() - - if maybe_recv_rank < 0: - # recv q, no one needs to do remote computation in the last time step - if time_step < (seq_world_size // 2 - 1): - # print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)") - # q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) - all_handles.append( - P2POp( - op=irecv, - tensor=peer_q, - peer=maybe_recv_rank % seq_world_size, - group=seq_group, - ) - ) - if debug: - _fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() - else: - # recv kv - # print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)") - # kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) - # kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) - all_handles.append( - P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group) - ) - all_handles.append( - P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group) - ) - if debug: - _fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() - _fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() - - maybe_send_rank_o = seq_rank - (time_step - 1) - maybe_recv_rank_o = seq_rank + (time_step - 1) - if maybe_send_rank_o < 0 and time_step > 1: - for t in o_stats: - # print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)") - # o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) - all_handles.append( - P2POp( - op=isend, - tensor=t, - peer=maybe_send_rank_o % seq_world_size, - group=seq_group, - ) - ) - if debug: - _fwd_send_volume += torch.numel(t) * t.element_size() - if maybe_recv_rank_o >= seq_world_size and time_step > 1: - for t in o_stats: - # print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)") - # o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) - all_handles.append( - P2POp( - op=irecv, - tensor=t, - peer=maybe_recv_rank_o % 
seq_world_size, - group=seq_group, - ) - ) - if debug: - _fwd_recv_volume += torch.numel(t) * t.element_size() - - # reqs = [] - - if debug: - if seq_rank in [0, 8]: - print( - f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB" - ) - # return reqs - all_reqs = launch_async_handles(all_handles, comm_mode) - return [all_reqs] - - -# delta: may be you are using it for your local compute or as a distributed buffer to send to others -# .. Sorry for the bad naming.. -def maybe_send_recv_bwd_qkvo( - dq_delta: torch.Tensor, - dk_delta: torch.Tensor, - dv_delta: torch.Tensor, - dk_delta_from_peer: torch.Tensor, - dv_delta_from_peer: torch.Tensor, - q: torch.Tensor, - peer_q: torch.Tensor, - L: torch.Tensor, - peer_L: torch.Tensor, - k: torch.Tensor, - peer_k: torch.Tensor, - v: torch.Tensor, - peer_v: torch.Tensor, - o: torch.Tensor, - peer_o: torch.Tensor, - do: torch.Tensor, - peer_do: torch.Tensor, - time_step: int, - comm_mode, - debug=False, -): - seq_group = get_sequence_parallel_group() - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - all_handles = [] - maybe_send_rank = seq_rank + (time_step + 1) - maybe_recv_rank = seq_rank - (time_step + 1) - - if debug: - global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume - - if maybe_send_rank >= seq_world_size: - # send q, no one needs to do remote computation in the last time step - if time_step < (seq_world_size // 2 - 1): - all_handles.append( - P2POp( - op=isend, - tensor=q, - peer=maybe_send_rank % seq_world_size, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=isend, - tensor=L, - peer=maybe_send_rank % seq_world_size, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=isend, - tensor=o, - peer=maybe_send_rank % seq_world_size, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=isend, - tensor=do, - peer=maybe_send_rank % seq_world_size, - group=seq_group, - ) - ) - if debug: - _bwd_send_volume += torch.numel(q) * q.element_size() - _bwd_send_volume += torch.numel(L) * L.element_size() - _bwd_send_volume += torch.numel(o) * o.element_size() - _bwd_send_volume += torch.numel(do) * do.element_size() - else: - # send kv - all_handles.append( - P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group) - ) - all_handles.append( - P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group) - ) - if debug: - _bwd_send_volume += torch.numel(k) * k.element_size() - _bwd_send_volume += torch.numel(v) * v.element_size() - - if maybe_recv_rank < 0: - # recv q, no one needs to do remote computation in the last time step - if time_step < (seq_world_size // 2 - 1): - all_handles.append( - P2POp( - op=irecv, - tensor=peer_q, - peer=maybe_recv_rank % seq_world_size, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=irecv, - tensor=peer_L, - peer=maybe_recv_rank % seq_world_size, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=irecv, - tensor=peer_o, - peer=maybe_recv_rank % seq_world_size, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=irecv, - tensor=peer_do, - peer=maybe_recv_rank % seq_world_size, - group=seq_group, - ) - ) - if debug: - _bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() - _bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size() - _bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size() - _bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size() - 
else: - # recv kv - all_handles.append( - P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group) - ) - all_handles.append( - P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group) - ) - if debug: - _bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() - _bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() - - # Whether I should update dq, dk and dv after waiting these requests - is_update_dq = False - is_update_dkv = False - - maybe_send_rank_dqkv = seq_rank - (time_step - 1) - maybe_recv_rank_dqkv = seq_rank + (time_step - 1) - - if time_step > 1: - if maybe_send_rank_dqkv < 0: - # print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}") - all_handles.append( - P2POp( - op=isend, - tensor=dq_delta, - peer=maybe_send_rank_dqkv % seq_world_size, - group=seq_group, - ) - ) - if debug: - _bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size() - else: - # print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}") - all_handles.append( - P2POp( - op=isend, - tensor=dk_delta, - peer=maybe_send_rank_dqkv, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=isend, - tensor=dv_delta, - peer=maybe_send_rank_dqkv, - group=seq_group, - ) - ) - if debug: - _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() - _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() - - if maybe_recv_rank_dqkv >= seq_world_size: - # print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}") - all_handles.append( - P2POp( - op=irecv, - tensor=dq_delta, - peer=maybe_recv_rank_dqkv % seq_world_size, - group=seq_group, - ) - ) - is_update_dq = True - if debug: - _bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size() - else: - # print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}") - all_handles.append( - P2POp( - op=irecv, - tensor=dk_delta_from_peer, - peer=maybe_recv_rank_dqkv, - group=seq_group, - ) - ) - all_handles.append( - P2POp( - op=irecv, - tensor=dv_delta_from_peer, - peer=maybe_recv_rank_dqkv, - group=seq_group, - ) - ) - is_update_dkv = True - if debug: - _bwd_recv_volume += ( - torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size() - ) - _bwd_recv_volume += ( - torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size() - ) - - # return [], is_update_dq, is_update_dkv - all_reqs = launch_async_handles(all_handles, comm_mode) - return [all_reqs], is_update_dq, is_update_dkv - - -def maybe_send_recv_bwd_last_dkv( - dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False -): - is_update_last_dkv = False - - seq_group = get_sequence_parallel_group() - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - if seq_world_size == 1: - return [], is_update_last_dkv - - all_handles = [] - - if debug: - global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume - - if time_step == seq_world_size // 2: - maybe_send_rank = seq_rank - time_step - maybe_recv_rank = seq_rank + time_step - - assert (maybe_send_rank >= 0) ^ ( - maybe_recv_rank < seq_world_size - ), "R={seq_rank} should be either sending or receiving dkv in the last time step." 
- - if maybe_send_rank >= 0: - # print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}") - all_handles.append( - P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group) - ) - all_handles.append( - P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group) - ) - if debug: - _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() - _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() - if maybe_recv_rank < seq_world_size: - # print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}") - all_handles.append( - P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group) - ) - all_handles.append( - P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group) - ) - if debug: - _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size() - _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size() - is_update_last_dkv = True - - # return [], is_update_last_dkv - all_reqs = launch_async_handles(all_handles, comm_mode) - - return [all_reqs], is_update_last_dkv - - -def print_and_reset_comm_stats(): - seq_rank = get_sequence_parallel_rank() - - global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume - _fwd_send_volume *= 1e-9 - _fwd_recv_volume *= 1e-9 - _bwd_send_volume *= 1e-9 - _bwd_recv_volume *= 1e-9 - - print( - f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB." - ) - _fwd_send_volume = 0 - _fwd_recv_volume = 0 - _bwd_send_volume = 0 - _bwd_recv_volume = 0 - - -def launch_async_handles(handles, comm_mode): - global _args - if comm_mode == "nocomm": - # print("skipping communication for ablation") - return [] - if len(handles) > 0: - return dist.batch_isend_irecv(handles) - return [] - - -def wait_async_handles(reqs): - if len(reqs) > 0: - for req in reqs: - for r in req: - r.wait() diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py b/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py deleted file mode 100644 index e7df1d88c..000000000 --- a/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn.py +++ /dev/null @@ -1,1209 +0,0 @@ -import argparse - -# from torch.profiler import profile, record_function, ProfilerActivity -import functools -import math -import os -import time - -import numpy as np -import pytest -import torch -import torch.distributed as dist -import triton -import triton.language as tl -from einops import rearrange -from torch.distributed import ReduceOp -from tqdm import tqdm - -try: - from flash_attn.flash_attn_interface import ( - _flash_attn_backward, - _flash_attn_forward, - ) -except: - pass - -from .async_communication import ( - get_sequence_parallel_rank, - get_sequence_parallel_size, - initialize_distributed, - is_compute_for_local_query, - is_idle, - is_last_time, - is_sync_from_remote, - launch_async_handles, - maybe_get_set_global_memory_buffer, - maybe_get_set_global_memory_buffer_bwd, - maybe_send_recv_bwd_last_dkv, - maybe_send_recv_bwd_qkvo, - maybe_send_recv_fwd_qkvo, - print_and_reset_comm_stats, - reset_global_memory_buffer, - wait_async_handles, -) - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def _rescale_kernel( - peer_m, - m, - peer_l, - l, - peer_o, - o, - L, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: 
tl.constexpr, - LAST_STEP: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - o_offset = off_hz * stride_oh - peer_o_block_ptr = tl.make_block_ptr( - base=peer_o + o_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - o_block_ptr = tl.make_block_ptr( - base=o + o_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - - peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m - m_ptrs = m + off_hz * N_CTX + offs_m - peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m - l_ptrs = l + off_hz * N_CTX + offs_m - - peer_m_i = tl.load(peer_m_ptrs) - peer_m_i = peer_m_i.to(tl.float32) - m_i = tl.load(m_ptrs) - m_i = m_i.to(tl.float32) - peer_l_i = tl.load(peer_l_ptrs) - peer_l_i = peer_l_i.to(tl.float32) - l_i = tl.load(l_ptrs) - l_i = l_i.to(tl.float32) - - peer_acc = tl.load(peer_o_block_ptr) - peer_acc = peer_acc.to(tl.float32) - acc = tl.load(o_block_ptr) - acc = acc.to(tl.float32) - lo = 0 - hi = N_CTX - m_i_sync = tl.maximum(m_i, peer_m_i) - alpha = tl.math.exp2(m_i - m_i_sync) - peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug - - acc *= acc_scale[:, None] - peer_acc *= peer_acc_scale[:, None] - acc += peer_acc - l_i = l_i * acc_scale + peer_l_i * peer_acc_scale - # write back O, l, m - tl.store(m_ptrs, m_i_sync) - tl.store(l_ptrs, l_i) - if LAST_STEP: - acc = acc / l_i[:, None] - L_ptrs = L + off_hz * N_CTX + offs_m - tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) - tl.store(o_block_ptr, acc.to(tl.bfloat16)) - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - m, - l, - O, - L, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, - LAST_STEP: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qvk_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - O_block_ptr = tl.make_block_ptr( - base=O + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l -> load from provided pointer - m_ptrs = m + off_hz * N_CTX + offs_m - l_ptrs = l + 
off_hz * N_CTX + offs_m - m_i = tl.load(m_ptrs) - m_i = m_i.to(tl.float32) - l_i = tl.load(l_ptrs) - l_i = l_i.to(tl.float32) - acc = tl.load(O_block_ptr) - acc = acc.to(tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) - q = (q * qk_scale).to(tl.bfloat16) - # loop over k, v and update accumulator - lo = 0 - hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX - for start_n in range(lo, hi, BLOCK_N): - # -- load k, v -- - k = tl.load(K_block_ptr) - v = tl.load(V_block_ptr) - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_CAUSAL: - qk = tl.where( - offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf") - ) - qk += tl.dot(q, k) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc += tl.dot(p.to(tl.bfloat16), v) - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - # write back original l and m - tl.store(m_ptrs, m_i) - tl.store(l_ptrs, l_i) - # write back O, L - if LAST_STEP: - acc = acc / l_i[:, None] - L_ptrs = L + off_hz * N_CTX + offs_m - tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) - tl.store(O_block_ptr, acc.to(tl.bfloat16)) - - -# for gqa/mqa to expand kv heads -def maybe_repeat_kv_fwd(nqh, kv): - bs, nkvh, slen, hdim = kv.shape - n_rep = nqh // nkvh - if n_rep == 1: - return kv - kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) - return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) - - -def maybe_repeat_kv_bwd(nqh, kv): - bs, slen, nkvh, hdim = kv.shape - n_rep = nqh // nkvh - if n_rep == 1: - return kv - kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) - return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) - - -# kv grad has shape bs, slen, nqh, hdim -def maybe_reduce_dkv(nkvh, dkv): - bs, slen, nqh, hdim = dkv.shape - n_rep = nqh // nkvh - if n_rep == 1: - return dkv - dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) - return torch.sum(dkv_reshape, dim=3) - - -def _lightseq_forward(q, k, v, causal, sm_scale, comm_mode): - # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - # Why do I have to change it from 128 64 to 32 32? 
- BLOCK_M = 32 - BLOCK_N = 32 - - bsz, nh, seq_len, hdim = q.shape - - m = torch.full( - (bsz * nh, seq_len), - fill_value=-float("inf"), - device=q.device, - dtype=torch.float32, - ) - l = torch.zeros_like(m) - L = torch.zeros_like(m) - o = torch.zeros_like(q) - - grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1) - num_warps = 4 if Lk <= 64 else 8 - - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - # Initialize all buffers - peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer( - q, k, v, m, l, o - ) - - fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[ - grid - ]( - q, - k, - v, - sm_scale, - m, - l, - o, - L, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=Lk, - IS_CAUSAL=IS_CAUSAL, - LAST_STEP=LAST_STEP, - num_warps=num_warps, - num_stages=4, - ) - - for time_step in range(seq_world_size // 2 + 1): - # This is important for cuda scheduler to execute nccl calls first. - torch.cuda.synchronize() - # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. - buffer_idx_1 = time_step % 2 - buffer_idx_2 = (time_step - 1) % 2 - - reqs = maybe_send_recv_fwd_qkvo( - q, - peer_q[buffer_idx_1], - k, - peer_k[buffer_idx_1], - v, - peer_v[buffer_idx_1], - [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], - time_step, - comm_mode, - ) - if comm_mode == "sync": - # if seq_rank == 0: - # print("Immediate wait for abalation") - wait_async_handles(reqs) - if is_compute_for_local_query(time_step): - # print(f"t={time_step}: (Comp) R={seq_rank} local compute") - if time_step == 0: - fwd_launch_helper( - q, - maybe_repeat_kv_fwd(q.shape[1], k), - maybe_repeat_kv_fwd(q.shape[1], v), - m, - l, - o, - L, - True, - is_last_time(time_step), - ) - else: - # if needs to sync from others, do not normalize here - fwd_launch_helper( - q, - maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), - maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), - m, - l, - o, - L, - False, - not is_sync_from_remote(time_step) and is_last_time(time_step), - ) - elif is_idle(time_step): - # print(f"t={time_step}: (Comp) R={seq_rank} idle") - pass - else: - # print(f"t={time_step}: (Comp) R={seq_rank} helps other") - peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) - peer_l[buffer_idx_2] = torch.zeros_like(l) - peer_o[buffer_idx_2] = torch.zeros_like(o) - - # print(f"rank 3 q is: {peer_q[buffer_idx_2]}") - fwd_launch_helper( - peer_q[buffer_idx_2], - maybe_repeat_kv_fwd(q.shape[1], k), - maybe_repeat_kv_fwd(q.shape[1], v), - peer_m[buffer_idx_2], - peer_l[buffer_idx_2], - peer_o[buffer_idx_2], - None, - False, - False, - ) - - if comm_mode == "lightseq": - # Make sure tensors for next steps are ready - wait_async_handles(reqs) - # sync between statistics get from other ranks and the local ones - if is_sync_from_remote(time_step): - _rescale_kernel[grid]( - peer_m[buffer_idx_1], - m, - peer_l[buffer_idx_1], - l, - peer_o[buffer_idx_1], - o, - L, - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - o.shape[0], - o.shape[1], - o.shape[2], - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=Lk, - LAST_STEP=is_last_time(time_step), - 
num_warps=num_warps, - num_stages=4, - ) - return q, k, v, o, L - - -def _lightseq_backward(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine): - BLOCK = 128 - q, k, v, o, do = [ - rearrange(_x, "b h s d -> b s h d").contiguous() for _x in [q, k, v, o, do] - ] - L = rearrange(L, "(b h) s -> b h s", b=q.shape[0]) - - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - - # maybe gqa - nqh = q.shape[2] - nkvh = k.shape[2] - is_gqa = nqh > nkvh - - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - # Initialize all backward buffers - ( - dq_delta, - dk_delta, - dv_delta, - dk_delta_from_peer, - dv_delta_from_peer, - peer_q, - peer_L, - peer_k, - peer_v, - peer_o, - peer_do, - ) = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) - - for time_step in range(0, get_sequence_parallel_size() // 2 + 1): - torch.cuda.synchronize() - buffer_idx_1 = time_step % 2 - buffer_idx_2 = (time_step - 1) % 2 - - reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo( - dq_delta[buffer_idx_1], - dk_delta[buffer_idx_1], - dv_delta[buffer_idx_1], - dk_delta_from_peer, - dv_delta_from_peer, - q, - peer_q[buffer_idx_1], - L, - peer_L[buffer_idx_1], - k, - peer_k[buffer_idx_1], - v, - peer_v[buffer_idx_1], - o, - peer_o[buffer_idx_1], - do, - peer_do[buffer_idx_1], - time_step, - comm_mode, - ) - if comm_mode == "sync": - # if seq_rank == 0: - # print("(bwd) Immediate wait for abalation") - wait_async_handles(reqs) - - if is_compute_for_local_query(time_step): - if time_step == 0: - if backward_engine == "flash": - _flash_attn_backward( - do, - q, - k, - v, - o, - L, - dq, - dk, - dv, - 0.0, - sm_scale, - True, - (-1, -1), - None, - False, - ) - else: - inp = Inputs( - query=q, - key=maybe_repeat_kv_bwd(q.shape[2], k), - value=maybe_repeat_kv_bwd(q.shape[2], v), - attn_bias=xformers.ops.LowerTriangularMask(), - p=0, - scale=sm_scale, - ) - op_ctx = Context(lse=L, out=o, rng_state=None) - # Let xformers dispatch the correct backend - grads = _memory_efficient_attention_backward( - ctx=op_ctx, inp=inp, grad=do, op=None - ) - dq = grads.dq - dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv( - nkvh, grads.dv - ) - else: - if backward_engine == "flash": - _flash_attn_backward( - do, - q, - peer_k[buffer_idx_2], - peer_v[buffer_idx_2], - o, - L, - dq_delta[buffer_idx_2], - dk_delta[buffer_idx_2], - dv_delta[buffer_idx_2], - 0.0, - sm_scale, - False, - (-1, -1), - None, - False, - ) - else: - inp = Inputs( - query=q, - key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), - value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), - attn_bias=None, - p=0, - scale=sm_scale, - ) - op_ctx = Context(lse=L, out=o, rng_state=None) - grads = _memory_efficient_attention_backward( - ctx=op_ctx, inp=inp, grad=do, op=None - ) - dq_delta[buffer_idx_2] = grads.dq - dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv( - nkvh, grads.dk - ), maybe_reduce_dkv(nkvh, grads.dv) - dq += dq_delta[buffer_idx_2] - elif is_idle(time_step): - pass - else: - if backward_engine == "flash": - _flash_attn_backward( - peer_do[buffer_idx_2], - peer_q[buffer_idx_2], - k, - v, - peer_o[buffer_idx_2], - peer_L[buffer_idx_2], - dq_delta[buffer_idx_2], - dk_delta[buffer_idx_2], - dv_delta[buffer_idx_2], - 0.0, - sm_scale, - False, - (-1, -1), - None, - False, - ) - else: - inp = Inputs( - query=peer_q[buffer_idx_2], - key=maybe_repeat_kv_bwd(q.shape[2], k), - value=maybe_repeat_kv_bwd(q.shape[2], v), - attn_bias=None, - p=0, - 
scale=sm_scale, - ) - op_ctx = Context( - lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None - ) - grads = _memory_efficient_attention_backward( - ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None - ) - dq_delta[buffer_idx_2] = grads.dq - dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv( - nkvh, grads.dk - ), maybe_reduce_dkv(nkvh, grads.dv) - dk += dk_delta[buffer_idx_2] - dv += dv_delta[buffer_idx_2] - - if comm_mode == "lightseq": - # Make sure tensors for next steps are ready - wait_async_handles(reqs) - - # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. - reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv( - dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode - ) - - if comm_mode == "sync": - # if seq_rank == 0: - # print("(bwd) dkv Immediate wait for abalation") - wait_async_handles(reqs) - # apply dq_delta, dk_delta and dv_delta from remote - if is_update_dq: - dq += dq_delta[buffer_idx_1] - if is_update_dkv: - dk += dk_delta_from_peer - dv += dv_delta_from_peer - - if comm_mode == "lightseq": - wait_async_handles(reqs) - # apply dk_delta and dv_delta to sender - if is_update_last_dkv: - dk += dk_delta[buffer_idx_2] - dv += dv_delta[buffer_idx_2] - - dq, dk, dv = [rearrange(_x, "b h s d -> b s h d") for _x in [dq, dk, dv]] - return dq, dk, dv - - -class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, causal, sm_scale): - try: - global args - comm_mode = args.comm_mode - backward_engine = args.backward_engine - except: - comm_mode = "lightseq" - backward_engine = "flash" - - q, k, v, o, L = _lightseq_forward(q, k, v, causal, sm_scale, comm_mode) - - ctx.save_for_backward(q, k, v, o, L) - ctx.sm_scale = sm_scale - ctx.comm_mode = comm_mode - ctx.backward_engine = backward_engine - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, L = ctx.saved_tensors - sm_scale = ctx.sm_scale - - dq, dk, dv = _lightseq_backward( - do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine - ) - return dq, dk, dv, None, None - - -attention = _attention.apply - - -# @pytest.mark.parametrize('causal', [False, True]) -# @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) -def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): - torch.manual_seed(20) - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - seq_per_rank = N_CTX // world_size - - sm_scale = 0.5 - dout = torch.randn_like(q) - # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - assert causal - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - - # triton implementation - - a, b, c, d = q.size() - real_q = ( - q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - 
.requires_grad_(True) - ) - real_k = ( - k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_v = ( - v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_do = ( - dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - - tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() - - # compare - assert torch.allclose( - ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_out, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails forward" - print(f" *** rank {rank} passes forward") - tri_out.backward(real_do) - tri_dv, real_v.grad = real_v.grad.clone(), None - tri_dk, real_k.grad = real_k.grad.clone(), None - tri_dq, real_q.grad = real_q.grad.clone(), None - assert torch.allclose( - ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dq, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails backward dq" - assert torch.allclose( - ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dk, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dk" # f" {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" - assert torch.allclose( - ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dv, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" - print(f"rank {rank} passes backward") - - -def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): - torch.manual_seed(177) - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - seq_per_rank = N_CTX // world_size - - sm_scale = 0.5 - dout = torch.randn_like(q) - # torch reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) - ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) - p = torch.matmul(q, ref_k.transpose(2, 3)) * sm_scale - assert causal - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, ref_v) - ref_out.backward(dout) - ref_dv, v.grad = ref_v.grad.clone(), None - ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1, 2))).transpose(1, 2) - ref_dk, k.grad = ref_k.grad.clone(), None - ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1, 2))).transpose(1, 2) - ref_dq, q.grad = q.grad.clone(), None - - # flash reference - from flash_attn import flash_attn_func, flash_attn_qkvpacked_func - - flash_q = q.transpose(1, 2).clone().detach().requires_grad_(True) - flash_k = k.transpose(1, 2).clone().detach().requires_grad_(True) - flash_v = 
v.transpose(1, 2).clone().detach().requires_grad_(True) - flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) - flash_ref_out.backward(dout.transpose(1, 2)) - flash_ref_out = flash_ref_out.transpose(1, 2) - flash_ref_dv, v.grad = flash_v.grad.clone(), None - flash_ref_dv = flash_ref_dv.transpose(1, 2) - flash_ref_dk, k.grad = flash_k.grad.clone(), None - flash_ref_dk = flash_ref_dk.transpose(1, 2) - flash_ref_dq, q.grad = flash_q.grad.clone(), None - flash_ref_dq = flash_ref_dq.transpose(1, 2) - - # triton implementation - - a, b, c, d = q.size() - real_q = ( - q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_k = ( - k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, KVH, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_v = ( - v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, KVH, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_do = ( - dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - - tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() - - # compare - assert torch.allclose( - flash_ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_out, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails forward against flash" - print(f" *** rank {rank} passes forward") - tri_out.backward(real_do) - tri_dv, real_v.grad = real_v.grad.clone(), None - tri_dk, real_k.grad = real_k.grad.clone(), None - tri_dq, real_q.grad = real_q.grad.clone(), None - assert torch.allclose( - flash_ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dq, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails backward dq against flash" - # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) - assert torch.allclose( - flash_ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dk, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" - assert torch.allclose( - flash_ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dv, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" - print(f"rank {rank} passes backward against flash") - - assert torch.allclose( - ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_out, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails forward" - print(f" *** rank {rank} passes forward") - assert torch.allclose( - ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dq, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails backward dq" - # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) - assert torch.allclose( - ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dk, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, 
:]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" - assert torch.allclose( - ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dv, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" - print(f"rank {rank} passes backward") - - -# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 -try: - from flash_attn.flash_attn_interface import ( - flash_attn_qkvpacked_func as flash_attn_func, - ) - - FLASH_VER = 2 -except BaseException: - try: - from flash_attn.flash_attn_interface import flash_attn_func - - FLASH_VER = 1 - except BaseException: - FLASH_VER = None -HAS_FLASH = FLASH_VER is not None -HAS_FLASH = None -ONLY_FLASH = False - -# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 -BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 -# vary seq length for fixed head and batch=4 -configs = [ - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[ - 2**i for i in range(18, 19) - ], # [ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], - line_arg="provider", - line_vals=["triton"] - if not ONLY_FLASH - else [] + (["flash"] if HAS_FLASH else []), - line_names=["Triton"] - if not ONLY_FLASH - else [] + ([f"Flash-{FLASH_VER}"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-")], - ylabel="ms", - plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}", - args={ - "H": N_HEADS, - "BATCH": BATCH, - "D_HEAD": D_HEAD, - "dtype": torch.bfloat16, - "mode": mode, - "causal": causal, - }, - ) - for mode in ["all"] - for causal in [True] -] - - -# @triton.testing.perf_report(configs) -def bench_flash_attention( - BATCH, - H, - KVH, - N_CTX, - D_HEAD, - causal, - mode, - provider, - args, - dtype=torch.bfloat16, - device="cuda", -): - assert mode == "all" # mode in ['fwd', 'bwd'] - n_warmup = 10 - n_repeat = 10 - cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - if provider == "triton": - q = torch.randn( - (BATCH, H, N_CTX // seq_world_size, D_HEAD), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - k = torch.randn( - (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - v = torch.randn( - (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - if seq_rank == 0: - print(f"Benchmarking per GPU qkv shape: {q.shape}") - sm_scale = 1.3 - fwd_fn = lambda: attention(q, k, v, causal, sm_scale) - if provider == "flash": - qkv = torch.randn( - (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True - ) - if FLASH_VER == 1: - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) - cu_seqlens[1:] = lengths.cumsum(0) - qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) - fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0.0, N_CTX, causal=causal) - elif FLASH_VER == 2: - fwd_fn = lambda: flash_attn_func(qkv, causal=causal) - else: - raise ValueError(f"unknown {FLASH_VER = }") - - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size - attn_flops = 2 * flops_per_matmul - - assert causal - if causal: - attn_flops *= 0.5 - fwd_flops = attn_flops - bwd_flops 
= attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) - - o = fwd_fn() - do = torch.randn_like(o) - bwd_fn = lambda: o.backward(do, retain_graph=True) - - def run_benchmark(fn): - time_list = [] - for _ in tqdm(range(n_warmup)): - cache.zero_() - fn() - torch.cuda.synchronize() - if args.debug: - print_and_reset_comm_stats() - for i in tqdm(range(n_repeat)): - cache.zero_() - torch.cuda.synchronize() - time_s = time.time() - fn() - torch.cuda.synchronize() - time_e = time.time() - time_list.append((time_e - time_s) * 1000.0) - if args.debug: - print_and_reset_comm_stats() - return np.asarray(time_list) - - fwd_time_arr = run_benchmark(fwd_fn) - bwd_time_arr = run_benchmark(bwd_fn) - - fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 - print( - f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n" - ) - - bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 - print( - f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n" - ) - - # total - total_time_arr = fwd_time_arr + bwd_time_arr - total_flops = fwd_flops + bwd_flops - total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 - print( - f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n" - ) - - # return total_flops_ps - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--comm-mode", type=str, default="lightseq") - parser.add_argument("--debug", action="store_true") - parser.add_argument("--run-mode", type=str, default="benchmark") - parser.add_argument("--bs", type=int, default=1) - parser.add_argument("--n_heads", type=int, default=32) - parser.add_argument("--n_kvheads", type=int, default=32) - parser.add_argument("--d_head", type=int, default=128) - parser.add_argument("--start_ctx", type=int, default=12) - parser.add_argument("--end_ctx", type=int, default=18) - parser.add_argument("--forward_engine", type=str, default="triton") - parser.add_argument("--backward_engine", type=str, default="flash") - - global args - args = parser.parse_args() - initialize_distributed() - - assert args.forward_engine == "triton", "Only triton forward is implmented." - assert args.backward_engine in [ - "flash", - "xformers", - ], "Only flash or xformers backward is implemented." - - if args.backward_engine == "flash": - from flash_attn.flash_attn_interface import ( - _flash_attn_backward, - _flash_attn_forward, - ) - else: - try: - import xformers.ops - from xformers.ops.fmha import ( - _memory_efficient_attention_backward, - cutlass, - flash, - ) - from xformers.ops.fmha.common import Context, Inputs - except ImportError: - print("xformers not found! 
Please install it before trying to use it.") - - if args.run_mode == "benchmark": - for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: - bench_flash_attention( - args.bs, - args.n_heads, - args.n_kvheads, - N_CTX, - args.d_head, - True, - "all", - "triton", - args, - ) # .run(save_path='.', print_data=True) - reset_global_memory_buffer() - else: - assert args.run_mode == "test" - for N_CTX in [2048, 4096]: - test_op(1, 16, N_CTX, 128, True) - # test_gqa(1, 16, 8, N_CTX, 128, True) - reset_global_memory_buffer() diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py b/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py deleted file mode 100644 index 2f1916903..000000000 --- a/src/axolotl/integrations/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py +++ /dev/null @@ -1,1273 +0,0 @@ -import argparse -import math -import os -import time - -import numpy as np -import pytest -import torch -import torch.distributed as dist -import triton -import triton.language as tl -from einops import rearrange -from torch.distributed import ReduceOp -from tqdm import tqdm - -# from torch.profiler import profile, record_function, ProfilerActivity - - -try: - from flash_attn.flash_attn_interface import _flash_attn_varlen_backward -except: - pass - -from .async_communication import ( - get_sequence_parallel_rank, - get_sequence_parallel_size, - initialize_distributed, - is_compute_for_local_query, - is_idle, - is_last_time, - is_sync_from_remote, - launch_async_handles, - maybe_get_set_global_memory_buffer, - maybe_get_set_global_memory_buffer_bwd, - maybe_send_recv_bwd_last_dkv, - maybe_send_recv_bwd_qkvo, - maybe_send_recv_fwd_qkvo, - print_and_reset_comm_stats, - reset_global_memory_buffer, - wait_async_handles, -) - - -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) - - -@triton.jit -def _rescale_kernel( - peer_m, - m, - peer_l, - l, - peer_o, - o, - L, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - seqlen_q_rounded, - seqlen_peer_q_rounded, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - LAST_STEP: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - o_offset = off_hz * stride_oh - peer_o_block_ptr = tl.make_block_ptr( - base=peer_o + o_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - o_block_ptr = tl.make_block_ptr( - base=o + o_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - - peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m - m_ptrs = m + off_hz * seqlen_q_rounded + offs_m - peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m - l_ptrs = l + off_hz * seqlen_q_rounded + offs_m - - peer_m_i = tl.load(peer_m_ptrs) - peer_m_i = peer_m_i.to(tl.float32) - m_i = tl.load(m_ptrs) - m_i = m_i.to(tl.float32) - peer_l_i = tl.load(peer_l_ptrs) - peer_l_i = peer_l_i.to(tl.float32) - l_i = tl.load(l_ptrs) - l_i = l_i.to(tl.float32) - - peer_acc = tl.load( - peer_o_block_ptr - ) # , boundary_check=(0, 1), padding_option='zero') - peer_acc = peer_acc.to(tl.float32) - acc = tl.load(o_block_ptr) # , boundary_check=(0, 1), padding_option='zero') - acc = 
acc.to(tl.float32) - lo = 0 - hi = N_CTX - m_i_sync = tl.maximum(m_i, peer_m_i) - alpha = tl.math.exp2(m_i - m_i_sync) - peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug - - acc *= acc_scale[:, None] - peer_acc *= peer_acc_scale[:, None] - acc += peer_acc - l_i = l_i * acc_scale + peer_l_i * peer_acc_scale - # write back O, l, m - tl.store(m_ptrs, m_i_sync) - tl.store(l_ptrs, l_i) - if LAST_STEP: - acc = acc / l_i[:, None] - L_ptrs = L + off_hz * N_CTX + offs_m - tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) - tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - m, - l, - O, - L, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - Z, - H, - N_CTX, - seqlen_q_rounded, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, - LAST_STEP: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qvk_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - O_block_ptr = tl.make_block_ptr( - base=O + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l -> load from provided pointer - # (TODO): Why float32? 
- m_ptrs = m + off_hz * seqlen_q_rounded + offs_m - l_ptrs = l + off_hz * seqlen_q_rounded + offs_m - m_i = tl.load(m_ptrs) - m_i = m_i.to(tl.float32) - l_i = tl.load(l_ptrs) - l_i = l_i.to(tl.float32) - acc = tl.load(O_block_ptr) - acc = acc.to(tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") - q = (q * qk_scale).to(tl.bfloat16) - # loop over k, v and update accumulator - lo = 0 - hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX - for start_n in range(lo, hi, BLOCK_N): - # -- load k, v -- - k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") - v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero") - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_CAUSAL: - qk = tl.where( - offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf") - ) - qk += tl.dot(q, k) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - alpha = tl.math.exp2(m_i - m_i_new) - p = tl.math.exp2(qk - m_i_new[:, None]) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc += tl.dot(p.to(tl.bfloat16), v) - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - # write back original l and m - tl.store(m_ptrs, m_i) - tl.store(l_ptrs, l_i) - # write back O, L - if LAST_STEP: - acc = acc / l_i[:, None] - L_ptrs = L + off_hz * seqlen_q_rounded + offs_m - tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) - tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) - - -# for gqa/mqa to expand kv heads -def maybe_repeat_kv_fwd(nqh, kv): - bs, nkvh, slen, hdim = kv.shape - n_rep = nqh // nkvh - if n_rep == 1: - return kv - kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) - return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) - - -def maybe_repeat_kv_bwd(nqh, kv): - bs, slen, nkvh, hdim = kv.shape - n_rep = nqh // nkvh - if n_rep == 1: - return kv - kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) - return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) - - -# kv grad has shape bs, slen, nqh, hdim -def maybe_reduce_dkv(nkvh, dkv): - bs, slen, nqh, hdim = dkv.shape - n_rep = nqh // nkvh - if n_rep == 1: - return dkv - # print("*"*100, dkv.shape, bs, slen, nkvh, n_rep, hdim) - dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) - # print("-"*100, dkv_reshape.shape, bs, slen, nkvh, n_rep, hdim) - return torch.sum(dkv_reshape, dim=3) - - -def _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode): - # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x - # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - # assert Lq == Lk and Lk == Lv - # assert Lk in {16, 32, 64, 128} - BLOCK_M = 128 - BLOCK_N = 64 - - bsz, nh, unpadded_seq_len, hdim = q.shape - cu_seq_lens = torch.arange( - 0, - (bsz + 1) * unpadded_seq_len, - unpadded_seq_len, - dtype=torch.int32, - device=q.device, - ) - max_seqlen = unpadded_seq_len - seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M - - m = torch.full( - (bsz * nh, 
seqlen_q_rounded), - fill_value=-float("inf"), - device=q.device, - dtype=torch.float32, - ) - l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) - L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) - o = torch.zeros_like(q) - - grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1) - num_warps = 4 if Lk <= 64 else 8 - - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - # Initialize all buffers - peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer( - q, k, v, m, l, o - ) - - fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[ - grid - ]( - q, - k, - v, - sm_scale, - m, - l, - o, - L, - q.stride(0), - q.stride(1), - q.stride(2), - q.stride(3), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - q.shape[0], - q.shape[1], - q.shape[2], - seqlen_q_rounded, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=Lk, - IS_CAUSAL=IS_CAUSAL, - LAST_STEP=LAST_STEP, - num_warps=num_warps, - num_stages=4, - ) - - for time_step in range(seq_world_size // 2 + 1): - # This is important for cuda scheduler to execute nccl calls first. - torch.cuda.synchronize() - # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. - buffer_idx_1 = time_step % 2 - buffer_idx_2 = (time_step - 1) % 2 - - reqs = maybe_send_recv_fwd_qkvo( - q, - peer_q[buffer_idx_1], - k, - peer_k[buffer_idx_1], - v, - peer_v[buffer_idx_1], - [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], - time_step, - comm_mode, - ) - if comm_mode == "sync": - # if seq_rank == 0: - # print("Immediate wait for abalation") - wait_async_handles(reqs) - if is_compute_for_local_query(time_step): - # print(f"t={time_step}: (Comp) R={seq_rank} local compute") - if time_step == 0: - fwd_launch_helper( - q, - maybe_repeat_kv_fwd(q.shape[1], k), - maybe_repeat_kv_fwd(q.shape[1], v), - m, - l, - o, - L, - True, - is_last_time(time_step), - ) - else: - # if needs to sync from others, do not normalize here - fwd_launch_helper( - q, - maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), - maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), - m, - l, - o, - L, - False, - not is_sync_from_remote(time_step) and is_last_time(time_step), - ) - elif is_idle(time_step): - # print(f"t={time_step}: (Comp) R={seq_rank} idle") - pass - else: - # print(f"t={time_step}: (Comp) R={seq_rank} helps other") - peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) - peer_l[buffer_idx_2] = torch.zeros_like(l) - peer_o[buffer_idx_2] = torch.zeros_like(o) - - # print(f"rank 3 q is: {peer_q[buffer_idx_2]}") - fwd_launch_helper( - peer_q[buffer_idx_2], - maybe_repeat_kv_fwd(q.shape[1], k), - maybe_repeat_kv_fwd(q.shape[1], v), - peer_m[buffer_idx_2], - peer_l[buffer_idx_2], - peer_o[buffer_idx_2], - None, - False, - False, - ) - - if comm_mode == "lightseq": - # Make sure tensors for next steps are ready - wait_async_handles(reqs) - # sync between statistics get from other ranks and the local ones - if is_sync_from_remote(time_step): - # print(f"t={time_step}: (Comp) R={seq_rank} sync with other - last time: {is_last_time(time_step)}") - seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1] - _rescale_kernel[grid]( - peer_m[buffer_idx_1], - m, - peer_l[buffer_idx_1], - l, - 
peer_o[buffer_idx_1], - o, - L, - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), - o.shape[0], - o.shape[1], - o.shape[2], - seqlen_q_rounded, - seqlen_peer_q_rounded, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=Lk, - LAST_STEP=is_last_time(time_step), - num_warps=num_warps, - num_stages=4, - ) - return q, k, v, o, L, cu_seq_lens, max_seqlen - - -def _lightseq_backward_varlen( - do, q, k, v, o, L, sm_scale, comm_mode, backward_engine, cu_seq_lens, max_seqlen -): - BLOCK = 128 - L = rearrange(L[:, :max_seqlen].contiguous(), "(b h) s -> b h s", b=q.shape[0]) - q, k, v, o, do = [ - rearrange(_x, "b h s d -> (b s) h d").contiguous() for _x in [q, k, v, o, do] - ] - - dq = torch.empty_like(q) - dk = torch.empty_like(k) - dv = torch.empty_like(v) - - # maybe gqa - nqh = q.shape[1] - nkvh = k.shape[1] - is_gqa = nqh > nkvh - - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - - # Initialize all backward buffers - ( - dq_delta, - dk_delta, - dv_delta, - dk_delta_from_peer, - dv_delta_from_peer, - peer_q, - peer_L, - peer_k, - peer_v, - peer_o, - peer_do, - ) = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) - - for time_step in range(0, get_sequence_parallel_size() // 2 + 1): - torch.cuda.synchronize() - buffer_idx_1 = time_step % 2 - buffer_idx_2 = (time_step - 1) % 2 - - reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo( - dq_delta[buffer_idx_1], - dk_delta[buffer_idx_1], - dv_delta[buffer_idx_1], - dk_delta_from_peer, - dv_delta_from_peer, - q, - peer_q[buffer_idx_1], - L, - peer_L[buffer_idx_1], - k, - peer_k[buffer_idx_1], - v, - peer_v[buffer_idx_1], - o, - peer_o[buffer_idx_1], - do, - peer_do[buffer_idx_1], - time_step, - comm_mode, - ) - if comm_mode == "sync": - wait_async_handles(reqs) - - if is_compute_for_local_query(time_step): - if time_step == 0: - assert ( - backward_engine == "flash" - ), "We haven't supportted varlen feature in xformer" - if backward_engine == "flash": - _flash_attn_varlen_backward( - do, - q, - k, - v, - o, - L, - dq, - dk, - dv, - cu_seq_lens, - cu_seq_lens, - max_seqlen, - max_seqlen, - 0.0, - sm_scale, - True, - None, - ) - else: - inp = Inputs( - query=q, - key=maybe_repeat_kv_bwd(q.shape[2], k), - value=maybe_repeat_kv_bwd(q.shape[2], v), - attn_bias=xformers.ops.LowerTriangularMask(), - p=0, - scale=sm_scale, - ) - op_ctx = Context(lse=L, out=o, rng_state=None) - # Let xformers dispatch the correct backend - grads = _memory_efficient_attention_backward( - ctx=op_ctx, inp=inp, grad=do, op=None - ) - dq = grads.dq - dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv( - nkvh, grads.dv - ) - else: - assert ( - backward_engine == "flash" - ), "We haven't supportted varlen feature in xformer" - if backward_engine == "flash": - _flash_attn_varlen_backward( - do, - q, - peer_k[buffer_idx_2], - peer_v[buffer_idx_2], - o, - L, - dq_delta[buffer_idx_2], - dk_delta[buffer_idx_2], - dv_delta[buffer_idx_2], - cu_seq_lens, - cu_seq_lens, - max_seqlen, - max_seqlen, - 0.0, - sm_scale, - False, - None, - ) - else: - inp = Inputs( - query=q, - key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), - value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), - attn_bias=None, - p=0, - scale=sm_scale, - ) - op_ctx = Context(lse=L, out=o, rng_state=None) - grads = _memory_efficient_attention_backward( - ctx=op_ctx, inp=inp, grad=do, op=None - ) - dq_delta[buffer_idx_2] = grads.dq - dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv( - nkvh, grads.dk - ), 
maybe_reduce_dkv(nkvh, grads.dv) - dq += dq_delta[buffer_idx_2] - elif is_idle(time_step): - # print(f"BWD t={time_step}: (Comp) R={seq_rank} idle") - pass - else: - # print(f"BWD t={time_step}: (Comp) R={seq_rank} helps other") - assert ( - backward_engine == "flash" - ), "We haven't supportted varlen feature in xformer" - if backward_engine == "flash": - _flash_attn_varlen_backward( - peer_do[buffer_idx_2], - peer_q[buffer_idx_2], - k, - v, - peer_o[buffer_idx_2], - peer_L[buffer_idx_2], - dq_delta[buffer_idx_2], - dk_delta[buffer_idx_2], - dv_delta[buffer_idx_2], - cu_seq_lens, - cu_seq_lens, - max_seqlen, - max_seqlen, - 0.0, - sm_scale, - False, - None, - ) - else: - inp = Inputs( - query=peer_q[buffer_idx_2], - key=maybe_repeat_kv_bwd(q.shape[2], k), - value=maybe_repeat_kv_bwd(q.shape[2], v), - attn_bias=None, - p=0, - scale=sm_scale, - ) - op_ctx = Context( - lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None - ) - grads = _memory_efficient_attention_backward( - ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None - ) - dq_delta[buffer_idx_2] = grads.dq - dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv( - nkvh, grads.dk - ), maybe_reduce_dkv(nkvh, grads.dv) - dk += dk_delta[buffer_idx_2] - dv += dv_delta[buffer_idx_2] - - if comm_mode == "lightseq": - # Make sure tensors for next steps are ready - wait_async_handles(reqs) - - # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. - reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv( - dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode - ) - - if comm_mode == "sync": - # if seq_rank == 0: - # print("(bwd) dkv Immediate wait for abalation") - wait_async_handles(reqs) - # apply dq_delta, dk_delta and dv_delta from remote - if is_update_dq: - dq += dq_delta[buffer_idx_1] - if is_update_dkv: - dk += dk_delta_from_peer - dv += dv_delta_from_peer - - if comm_mode == "lightseq": - wait_async_handles(reqs) - # apply dk_delta and dv_delta to sender - if is_update_last_dkv: - dk += dk_delta[buffer_idx_2] - dv += dv_delta[buffer_idx_2] - - dq, dk, dv = [ - rearrange(_x, "(b s) h d -> b h s d", s=max_seqlen) for _x in [dq, dk, dv] - ] - return dq, dk, dv - - -class _attention_varlen(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, causal, sm_scale): - try: - global args - comm_mode = args.comm_mode - backward_engine = args.backward_engine - except: - comm_mode = "lightseq" - backward_engine = "flash" - - q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen( - q, k, v, causal, sm_scale, comm_mode - ) - - ctx.save_for_backward(q, k, v, o, L, cu_seq_lens) - ctx.max_seqlen = max_seqlen - ctx.sm_scale = sm_scale - ctx.comm_mode = comm_mode - ctx.backward_engine = backward_engine - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, L, cu_seq_lens = ctx.saved_tensors - sm_scale = ctx.sm_scale - max_seqlen = ctx.max_seqlen - - dq, dk, dv = _lightseq_backward_varlen( - do, - q, - k, - v, - o, - L, - sm_scale, - ctx.comm_mode, - ctx.backward_engine, - cu_seq_lens, - max_seqlen, - ) - return dq, dk, dv, None, None - - -dist_attn_varlen = _attention_varlen.apply - - -# @pytest.mark.parametrize('causal', [False, True]) -# @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) -def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): - torch.manual_seed(20) - rank = dist.get_rank() - world_size = dist.get_world_size() - - PAD = world_size * 256 - 
seq_per_rank = (N_CTX - PAD) // world_size - q = ( - torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - - # DEBUG: mask out - # mask = torch.zeros(Z, H, seq_per_rank * (world_size - 1), D_HEAD).cuda() - # mask_2 = torch.ones(Z, H, seq_per_rank, D_HEAD).cuda() - # mask = torch.cat((mask, mask_2), dim=-2).to(dtype) - # q = mask * q - # k = mask * k - # v = mask * v - - sm_scale = 0.5 - dout = torch.randn_like(q) - # reference implementation - M = torch.tril(torch.ones((N_CTX - PAD, N_CTX - PAD), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - assert causal - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - - # triton implementation - - a, b, c, d = q.size() - real_q = ( - q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_k = ( - k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_v = ( - v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_do = ( - dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - - tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() - - # compare - assert torch.allclose( - ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_out, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails forward" - print(f" *** rank {rank} passes forward") - tri_out.backward(real_do) - tri_dv, real_v.grad = real_v.grad.clone(), None - tri_dk, real_k.grad = real_k.grad.clone(), None - tri_dq, real_q.grad = real_q.grad.clone(), None - assert torch.allclose( - ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dq, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dq" # {ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dq} {torch.max(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dq)} rank {rank} fails backward dk" - assert torch.allclose( - ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dk, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dk" # {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" - assert torch.allclose( - ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dv, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dv" # {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" - print(f"rank {rank} passes backward") - - -# TODO(High Priority): Investigate why rank 0 tends to 
have larger numerical difference. -def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): - torch.manual_seed(177) - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - seq_per_rank = N_CTX // world_size - - sm_scale = 0.5 - dout = torch.randn_like(q) - # torch reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) - ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) - # print(q.shape, ref_k.shape, k.shape) - p = torch.matmul(q, ref_k.transpose(2, 3)) * sm_scale - assert causal - if causal: - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, ref_v) - ref_out.backward(dout) - ref_dv, v.grad = ref_v.grad.clone(), None - # print("Before reduce", ref_dv.shape) - ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1, 2))).transpose(1, 2) - # print("After reduce", ref_dv.shape) - ref_dk, k.grad = ref_k.grad.clone(), None - ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1, 2))).transpose(1, 2) - ref_dq, q.grad = q.grad.clone(), None - - # flash reference - from flash_attn import flash_attn_func, flash_attn_qkvpacked_func - - flash_q = q.transpose(1, 2).clone().detach().requires_grad_(True) - flash_k = k.transpose(1, 2).clone().detach().requires_grad_(True) - flash_v = v.transpose(1, 2).clone().detach().requires_grad_(True) - flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) - flash_ref_out.backward(dout.transpose(1, 2)) - flash_ref_out = flash_ref_out.transpose(1, 2) - flash_ref_dv, v.grad = flash_v.grad.clone(), None - flash_ref_dv = flash_ref_dv.transpose(1, 2) - flash_ref_dk, k.grad = flash_k.grad.clone(), None - flash_ref_dk = flash_ref_dk.transpose(1, 2) - flash_ref_dq, q.grad = flash_q.grad.clone(), None - flash_ref_dq = flash_ref_dq.transpose(1, 2) - - # triton implementation - - a, b, c, d = q.size() - real_q = ( - q[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_k = ( - k[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, KVH, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_v = ( - v[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, KVH, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - real_do = ( - dout[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :] - .view(a, b, -1, d) - .contiguous() - .clone() - .detach() - .requires_grad_(True) - ) - - tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).half() - - # compare - assert torch.allclose( - flash_ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_out, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails forward against flash" - print(f" *** rank {rank} passes forward") - tri_out.backward(real_do) - tri_dv, real_v.grad = real_v.grad.clone(), None - tri_dk, real_k.grad = real_k.grad.clone(), None - tri_dq, real_q.grad = real_q.grad.clone(), None - assert torch.allclose( - 
flash_ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dq, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails backward dq against flash" - # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) - assert torch.allclose( - flash_ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dk, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" - assert torch.allclose( - flash_ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dv, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" - print(f"rank {rank} passes backward against flash") - - assert torch.allclose( - ref_out[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_out, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails forward" - print(f" *** rank {rank} passes forward") - assert torch.allclose( - ref_dq[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dq, - atol=1e-2, - rtol=0, - ), f" rank {rank} fails backward dq" - # print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) - assert torch.allclose( - ref_dk[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dk, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" - assert torch.allclose( - ref_dv[:, :, rank * seq_per_rank : (rank + 1) * seq_per_rank, :], - tri_dv, - atol=1e-2, - rtol=0, - ), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" - print(f"rank {rank} passes backward") - - -# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 -try: - from flash_attn.flash_attn_interface import ( - flash_attn_qkvpacked_func as flash_attn_func, - ) - - FLASH_VER = 2 -except BaseException: - try: - from flash_attn.flash_attn_interface import flash_attn_func - - FLASH_VER = 1 - except BaseException: - FLASH_VER = None -HAS_FLASH = FLASH_VER is not None -HAS_FLASH = None -ONLY_FLASH = False - -# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 -BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 -# vary seq length for fixed head and batch=4 -configs = [ - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[ - 2**i for i in range(18, 19) - ], # [ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], - line_arg="provider", - line_vals=["triton"] - if not ONLY_FLASH - else [] + (["flash"] if HAS_FLASH else []), - line_names=["Triton"] - if not ONLY_FLASH - else [] + ([f"Flash-{FLASH_VER}"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-")], - ylabel="ms", - plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}", - args={ - "H": N_HEADS, - "BATCH": BATCH, - "D_HEAD": D_HEAD, - "dtype": torch.bfloat16, - "mode": mode, - "causal": causal, - }, - ) - for mode in ["all"] - for 
causal in [True] -] - - -# @triton.testing.perf_report(configs) -def bench_flash_attention( - BATCH, - H, - KVH, - N_CTX, - D_HEAD, - causal, - mode, - provider, - args, - dtype=torch.bfloat16, - device="cuda", -): - assert mode == "all" # mode in ['fwd', 'bwd'] - n_warmup = 10 - n_repeat = 10 - cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") - seq_rank = get_sequence_parallel_rank() - seq_world_size = get_sequence_parallel_size() - if provider == "triton": - q = torch.randn( - (BATCH, H, N_CTX // seq_world_size, D_HEAD), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - k = torch.randn( - (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - v = torch.randn( - (BATCH, KVH, N_CTX // seq_world_size, D_HEAD), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - if seq_rank == 0: - print(f"Benchmarking per GPU qkv shape: {q.shape}") - sm_scale = 1.3 - fwd_fn = lambda: dist_attn_varlen(q, k, v, causal, sm_scale) - if provider == "flash": - qkv = torch.randn( - (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True - ) - if FLASH_VER == 1: - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) - cu_seqlens[1:] = lengths.cumsum(0) - qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) - fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0.0, N_CTX, causal=causal) - elif FLASH_VER == 2: - fwd_fn = lambda: flash_attn_func(qkv, causal=causal) - else: - raise ValueError(f"unknown {FLASH_VER = }") - - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size - attn_flops = 2 * flops_per_matmul - - assert causal - if causal: - attn_flops *= 0.5 - fwd_flops = attn_flops - bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) - - o = fwd_fn() - do = torch.randn_like(o) - bwd_fn = lambda: o.backward(do, retain_graph=True) - - def run_benchmark(fn): - time_list = [] - for _ in tqdm(range(n_warmup)): - cache.zero_() - fn() - torch.cuda.synchronize() - if args.debug: - print_and_reset_comm_stats() - for i in tqdm(range(n_repeat)): - cache.zero_() - torch.cuda.synchronize() - time_s = time.time() - fn() - torch.cuda.synchronize() - time_e = time.time() - time_list.append((time_e - time_s) * 1000.0) - if args.debug: - print_and_reset_comm_stats() - return np.asarray(time_list) - - fwd_time_arr = run_benchmark(fwd_fn) - bwd_time_arr = run_benchmark(bwd_fn) - - fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 - print( - f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n" - ) - - bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 - print( - f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n" - ) - - # total - total_time_arr = fwd_time_arr + bwd_time_arr - total_flops = fwd_flops + bwd_flops - total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 - print( - f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n" - ) - - # return total_flops_ps - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--comm-mode", type=str, default="lightseq") - parser.add_argument("--debug", action="store_true") - parser.add_argument("--run-mode", type=str, default="test") - parser.add_argument("--bs", type=int, default=1) - parser.add_argument("--n_heads", type=int, default=32) - 
parser.add_argument("--n_kvheads", type=int, default=32) - parser.add_argument("--d_head", type=int, default=128) - parser.add_argument("--start_ctx", type=int, default=12) - parser.add_argument("--end_ctx", type=int, default=18) - parser.add_argument("--forward_engine", type=str, default="triton") - parser.add_argument("--backward_engine", type=str, default="flash") - - global args - args = parser.parse_args() - initialize_distributed() - - assert args.forward_engine == "triton", "Only triton forward is implmented." - assert args.backward_engine in [ - "flash", - "xformers", - ], "Only flash or xformers backward is implemented." - - if args.backward_engine == "flash": - from flash_attn.flash_attn_interface import ( - _flash_attn_backward, - _flash_attn_forward, - ) - else: - try: - import xformers.ops - from xformers.ops.fmha import ( - _memory_efficient_attention_backward, - cutlass, - flash, - ) - from xformers.ops.fmha.common import Context, Inputs - except ImportError: - print("xformers not found! Please install it before trying to use it.") - - if args.run_mode == "benchmark": - for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: - bench_flash_attention( - args.bs, - args.n_heads, - args.n_kvheads, - N_CTX, - args.d_head, - True, - "all", - "triton", - args, - ) # .run(save_path='.', print_data=True) - reset_global_memory_buffer() - else: - assert args.run_mode == "test" - for N_CTX in [4096]: - test_op(2, 16, N_CTX, 128, True) - # test_gqa(1, 16, 8, N_CTX, 128, True) - reset_global_memory_buffer() diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py deleted file mode 100644 index 927671909..000000000 --- a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py +++ /dev/null @@ -1,754 +0,0 @@ -""" -Materialization-aware gradient checkpointing monkey patch. 
-""" -from typing import List, Optional, Tuple - -import torch -from einops import rearrange -from torch.utils.checkpoint import ( - _get_autocast_kwargs, - check_backward_validity, - detach_variable, - get_device_states, - set_device_states, -) -from transformers.models.llama.modeling_llama import ( - BaseModelOutputWithPast, - LlamaDecoderLayer, - LlamaModel, - apply_rotary_pos_emb, -) - -from .async_communication import initialize_distributed -from .lightseq_async_attn import _lightseq_backward, _lightseq_forward - -# define a global buffer to save flash attention outputs -# it's called global because it saves the outputs for all layers -global_flash_attn_out_buffer = None - -# define a local buffer to save recomputed qkv -# it's called local because it's a temporary buffer which will be updated across layers -local_res_grad_buffer = None - -# hooks for the gradients of residual -global_hooks = [] - - -def init_flash_attn_buffers(num_layers): - # update the global buffer according to number of layers - global global_flash_attn_out_buffer - global_flash_attn_out_buffer = [None] * num_layers - - -def clean_hook(): - # Remove all hooks in the global buffer - for hook in global_hooks: - hook.remove() - # Clear the global buffer - global_hooks.clear() - - -def clear_all_buffers_at_the_end_of_training(): - # call it at the end of training - global lobal_flash_attn_out_buffer - global_flash_attn_out_buffer = None - global local_res_grad_buffer - local_res_grad_buffer = None - clean_hook() - - -def save_flash_attn_out_to_global_buffer(idx, out): - global global_flash_attn_out_buffer - global_flash_attn_out_buffer[idx] = out - - -def get_flash_attn_out_from_global_buffer(idx): - global global_flash_attn_out_buffer - return global_flash_attn_out_buffer[idx] - - -def free_flash_attn_out_buffer(idx): - global global_flash_attn_out_buffer - global_flash_attn_out_buffer[idx] = None - - -def write_gradient_to_flash_attn_out(idx, grad): - global global_flash_attn_out_buffer - global_flash_attn_out_buffer[idx].grad = grad - - -def save_res_grad_hook(grad): - global local_res_grad_buffer - local_res_grad_buffer = grad - - -def load_and_add_res_grad_hook(grad): - grad += get_res_grad_from_local_buffer() - - -def get_res_grad_from_local_buffer(): - global local_res_grad_buffer - assert local_res_grad_buffer is not None - return local_res_grad_buffer - - -class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function): - """Avoid doing twice flash attention forward during checkpointed backward. - args: - hidden_states, # i.e., flash attention output which is saved in global buffer. - attention_mask, - position_ids, - residual, # the gradient of residual is saved in local buffer to pass across ckpt layers. - """ - - @staticmethod - def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): - check_backward_validity(args) - ctx.run_function = run_function - ctx.layer_idx = layer_idx - ctx.preserve_rng_state = preserve_rng_state - # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. - ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() - if preserve_rng_state: - ctx.fwd_cpu_state = torch.get_rng_state() - # Don't eagerly initialize the cuda context by accident. - # (If the user intends that the context is initialized later, within their - # run_function, we SHOULD actually stash the cuda state here. Unfortunately, - # we have no way to anticipate this will happen before we run the function.) 
- ctx.had_cuda_in_fwd = False - if torch.cuda._initialized: - ctx.had_cuda_in_fwd = True - ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) - - # Save non-tensor inputs in ctx, keep a placeholder None for tensors - # to be filled out during the backward. - ctx.inputs = [] - ctx.tensor_indices = [] - tensor_inputs = [] - for i, arg in enumerate(args): - if i == 0 and ctx.layer_idx != 0: - # flash attention output is saved to the global buffer during forward - ctx.inputs.append(None) - else: - if torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - with torch.no_grad(): - q, k, v, residual = run_function(*args) - softmax_scale = q.shape[-1] ** (-0.5) - - # lightseq version - _, _, _, out, softmax_lse = _lightseq_forward( - q, k, v, True, softmax_scale, comm_mode="lightseq" - ) - rng_state = None - - # save flash attention output to global buffer - save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) - tensor_inputs += [softmax_lse] - ctx.softmax_scale = softmax_scale - - ctx.save_for_backward(*tensor_inputs) - - return out, residual - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument." - ) - # Copy the list to avoid modifying original list. - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - tensors, softmax_lse = tensors[:-1], tensors[-1] - - # Fill in inputs with appropriate saved tensors. - # Fill the flash attention output first - if ctx.layer_idx > 0: - # inputs[0] should be flash attention output - inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx - 1) - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - - # Stash the surrounding rng state, and mimic the state that was - # present at this time during forward. Restore the surrounding state - # when we're done. 
- rng_devices = [] - if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: - rng_devices = ctx.fwd_gpu_devices - with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): - if ctx.preserve_rng_state: - torch.set_rng_state(ctx.fwd_cpu_state) - if ctx.had_cuda_in_fwd: - set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) - detached_inputs = detach_variable(tuple(inputs)) - with torch.enable_grad(), torch.cuda.amp.autocast( - **ctx.gpu_autocast_kwargs - ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): - # Stop recomputation before flash attention - # It is unecessary to run recomputation for flash attn - q, k, v, residual = ctx.run_function(*detached_inputs) - - # run backward() with only tensor that requires grad - # run flash attention backward first: - # get 'dout' from auto_grad inputs - # get 'out' from global buffer - # get 'qkv' from the recomputed tensors - # dq = torch.empty(q.shape, dtype=q.dtype, device=q.device) - # dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) - # dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) - out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) - # todo get dout - dout = args[0] - - # lightseq version - dq, dk, dv = _lightseq_backward( - dout, - q, - k, - v, - out, - softmax_lse, - ctx.softmax_scale, - comm_mode="lightseq", - backward_engine="flash", - ) - # dqkv = torch.stack([dq, dk, dv]) - - # run backward for the part before flash attention - # qkv.backward(dqkv) - torch.autograd.backward([q, k, v], [dq, dk, dv]) - - grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs - ) - - # write flash attention output gradients to buffer - if ctx.layer_idx > 0: - write_gradient_to_flash_attn_out(ctx.layer_idx - 1, detached_inputs[0].grad) - - return (None, None, None) + grads - - -def checkpoint_end_with_flash_attention( - function, layer_idx, *args, use_reentrant: bool = True, **kwargs -): - # Hack to mix *args with **kwargs in a python 2.7-compliant way - preserve = kwargs.pop("preserve_rng_state", True) - if kwargs and use_reentrant: - raise ValueError( - "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) - ) - - return CheckpointFunctionEndWithFlashAttention.apply( - function, layer_idx, preserve, *args - ) - - -class CheckpointFunctionLastModule(torch.autograd.Function): - """ - for the last ffn layer after flash attention, modifications include: - write the gradients wrt flash attention output and residual to the global buffer. - """ - - @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): - check_backward_validity(args) - ctx.run_function = run_function - ctx.preserve_rng_state = preserve_rng_state - # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. - ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() - if preserve_rng_state: - ctx.fwd_cpu_state = torch.get_rng_state() - # Don't eagerly initialize the cuda context by accident. - # (If the user intends that the context is initialized later, within their - # run_function, we SHOULD actually stash the cuda state here. Unfortunately, - # we have no way to anticipate this will happen before we run the function.) - ctx.had_cuda_in_fwd = False - if torch.cuda._initialized: - ctx.had_cuda_in_fwd = True - ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) - - # Save non-tensor inputs in ctx, keep a placeholder None for tensors - # to be filled out during the backward. 
- ctx.inputs = [] - ctx.tensor_indices = [] - tensor_inputs = [] - - assert torch.is_tensor( - args[0] - ), "assuming the first tensor is the flash attention output" - for i, arg in enumerate(args): - if torch.is_tensor(arg) and i == 0: - # flash attn output has been saved to global buffer - ctx.inputs.append(None) - elif torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - with torch.no_grad(): - outputs = run_function(*args) - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument." - ) - # Copy the list to avoid modifying original list. - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - - # Fill in inputs with appropriate saved tensors. - # Fill the flash attention output first - # inputs[0] should be flash attention output - inputs[0] = get_flash_attn_out_from_global_buffer(-1) - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - - # Stash the surrounding rng state, and mimic the state that was - # present at this time during forward. Restore the surrounding state - # when we're done. - rng_devices = [] - if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: - rng_devices = ctx.fwd_gpu_devices - with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): - if ctx.preserve_rng_state: - torch.set_rng_state(ctx.fwd_cpu_state) - if ctx.had_cuda_in_fwd: - set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) - detached_inputs = detach_variable(tuple(inputs)) - with torch.enable_grad(), torch.cuda.amp.autocast( - **ctx.gpu_autocast_kwargs - ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - - # run backward() with only tensor that requires grad - outputs_with_grad = [] - args_with_grad = [] - for i in range(len(outputs)): - if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: - outputs_with_grad.append(outputs[i]) - args_with_grad.append(args[i]) - if len(outputs_with_grad) == 0: - raise RuntimeError( - "none of output has requires_grad=True," - " this checkpoint() is not necessary" - ) - torch.autograd.backward(outputs_with_grad, args_with_grad) - grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs - ) - - # write flash attention output gradients to buffer - write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) - - return (None, None) + grads - - -def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): - preserve = kwargs.pop("preserve_rng_state", True) - if kwargs and use_reentrant: - raise ValueError( - "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) - ) - - return CheckpointFunctionLastModule.apply(function, preserve, *args) - - -def llama_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - compute_attn_only: Optional[bool] = False, - compute_ffn_only: 
Optional[bool] = False, - residual: Optional[bool] = None, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - assert compute_ffn_only or compute_attn_only - - if compute_attn_only: - residual = hidden_states - - if residual.requires_grad: - # register a hook to add the gradient of residual - # from next checkpoint layer when doing recomputation - hook = residual.register_hook(load_and_add_res_grad_hook) - global_hooks.append(hook) - - hidden_states = self.input_layernorm(hidden_states) - - # Flash Attention - bsz, q_len, _ = hidden_states.size() - try: - query_states = ( - self.self_attn.q_proj(hidden_states) - .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.self_attn.k_proj(hidden_states) - .view( - bsz, - q_len, - self.self_attn.num_key_value_heads, - self.self_attn.head_dim, - ) - .transpose(1, 2) - ) - value_states = ( - self.self_attn.v_proj(hidden_states) - .view( - bsz, - q_len, - self.self_attn.num_key_value_heads, - self.self_attn.head_dim, - ) - .transpose(1, 2) - ) - except: - # old transformers versions don't support num_key_value_heads - query_states = ( - self.self_attn.q_proj(hidden_states) - .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.self_attn.k_proj(hidden_states) - .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.self_attn.v_proj(hidden_states) - .view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim) - .transpose(1, 2) - ) - - kv_seq_len = key_states.shape[-2] - assert past_key_value is None, "past_key_value is not supported" - - cos, sin = self.self_attn.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - # [bsz, nh, t, hd] - assert not output_attentions, "output_attentions is not supported" - assert not use_cache, "use_cache is not supported" - return ( - query_states.contiguous(), - key_states.contiguous(), - value_states.contiguous(), - residual, - ) - - elif compute_ffn_only: - hidden_states = self.self_attn.o_proj( - rearrange(hidden_states, "b h s d -> b s (h d)") - ) - # Need to add residual here to make sure checkpoint is right after attention - if residual.requires_grad: - # save the gradient of residual to the local buffer - # collect the hooks which should be removed after backward to avoid memory leak - hook = residual.register_hook(save_res_grad_hook) - global_hooks.append(hook) - - hidden_states = residual + hidden_states - - # Fully Connected - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = 
residual + hidden_states - - outputs = (hidden_states,) - - else: - raise AttributeError - - return outputs - - -def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, -): - assert cache_position is None, "cache_position is not supported" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - attention_mask = None - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - try: - logger.warning_once("***** Using fast gradient checkpointing... *****") - except: - pass - # initialize the global buffer - init_flash_attn_buffers(len(self.layers)) - - if use_cache: - try: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - except: - pass - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # apply flash-attention friendly gradient checkpointing - if self.gradient_checkpointing and self.training: - for idx in range(len(self.layers) + 1): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - def forward_first_attn_module(module): - def custom_forward(*inputs): - hidden_states, attention_mask, position_ids, _ = inputs - # None for past_key_value - return module( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - compute_attn_only=True, - ) - - return custom_forward - - def forward_ffn_attn_layer(module1, module2): - def custom_forward(*inputs): - hidden_states, attention_mask, position_ids, residual = inputs - # None for past_key_value - layer_outputs = module1( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - compute_ffn_only=True, - residual=residual, - ) - hidden_states = layer_outputs[0] - return module2( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - compute_attn_only=True, - ) - - return custom_forward - - def forward_last_ffn_module(module): - def custom_forward(*inputs): - hidden_states, attention_mask, position_ids, residual = inputs - # None for past_key_value - return module( - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - compute_ffn_only=True, - residual=residual, - ) - - return custom_forward - - if idx == 0: - layer_outputs = checkpoint_end_with_flash_attention( - forward_first_attn_module(self.layers[0]), - idx, - hidden_states, - attention_mask, - position_ids, - None, - ) - hidden_states, residual = layer_outputs[0], layer_outputs[-1] - elif idx == len(self.layers): - layer_outputs = checkpoint_last_module( - forward_last_ffn_module(self.layers[-1]), - hidden_states, - attention_mask, - position_ids, - residual, - ) - hidden_states = layer_outputs[0] - else: - layer_outputs = checkpoint_end_with_flash_attention( - forward_ffn_attn_layer(self.layers[idx - 1], self.layers[idx]), - idx, - hidden_states, - attention_mask, - position_ids, - residual, - ) - hidden_states, residual = layer_outputs[0], layer_outputs[-1] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - else: - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in 
[hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -def apply_dist_flash_attn_monkey_patch_llama(): - initialize_distributed() - - LlamaModel.forward = forward - LlamaDecoderLayer.forward = llama_layer_forward diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py b/src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py deleted file mode 100644 index 64245a143..000000000 --- a/src/axolotl/integrations/easy_context/dist_flash_attn/prepare_input.py +++ /dev/null @@ -1,34 +0,0 @@ -def extract_local(value, rank, world_size, device, dim=1): - value_local = value.chunk(world_size, dim=dim)[rank] - return value_local.to(device) - - -def prepare_dist_flash_attn_inputs( - input_ids, position_ids, target_ids, rank, world_size, device -): - local_input_ids = extract_local( - input_ids, - rank, - world_size, - device, - ) - local_position_ids = extract_local( - position_ids, - rank, - world_size, - device, - ) - if target_ids is not None: - local_target_ids = extract_local( - target_ids, - rank, - world_size, - device, - ) - else: - local_target_ids = None - return { - "local_input_ids": local_input_ids, - "local_position_ids": local_position_ids, - "local_target_ids": local_target_ids, - } diff --git a/src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py deleted file mode 100644 index 8221f6f91..000000000 --- a/src/axolotl/integrations/easy_context/ulysses_attn/monkey_patch.py +++ /dev/null @@ -1,114 +0,0 @@ -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -import transformers - -try: - from yunchang.ulysses import UlyssesAttention - - ulysses_attn = UlyssesAttention() -except: - print( - "If you want to use the UlyssesAttention class, please install the yunchang package." - ) - ulysses_attn = None - - -def new_flash_attn_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, -): - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - assert attention_mask is None - assert causal is True - assert use_sliding_windows is False - attn_output = ulysses_attn( - query_states, - key_states, - value_states, - dropout, - softmax_scale, - causal=causal, - ) - - return attn_output - - -def new_decoder_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - assert isinstance( - self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 - ) or isinstance( - self.self_attn, - transformers.models.mistral.modeling_mistral.MistralFlashAttention2, - ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 
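
The `prepare_dist_flash_attn_inputs` helper deleted above (and the `ulysses_attn` variant deleted just below) both hand each rank one contiguous slice of the sequence, slicing `position_ids` with the same bounds so rotary embeddings still see the global token positions. A minimal standalone sketch of that sharding, using toy tensors and no process group (the names here are illustrative, not part of the removed module):

```
import torch


def extract_local(value: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # Same contiguous chunking as the removed prepare-input helpers: split the
    # sequence dimension into `world_size` chunks and keep this rank's chunk.
    return value.chunk(world_size, dim=dim)[rank]


# Toy illustration: batch_size=1, seq_len=8, world_size=4.
input_ids = torch.arange(8).unsqueeze(0)      # shape (1, 8)
position_ids = torch.arange(8).unsqueeze(0)   # sliced identically to input_ids
locals_per_rank = [
    (extract_local(input_ids, r, 4), extract_local(position_ids, r, 4)) for r in range(4)
]
# rank 0 keeps tokens/positions [0, 1], rank 1 keeps [2, 3], rank 2 [4, 5], rank 3 [6, 7].
```
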
- - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -def apply_ulysses_attn_monkey_patch_llama(): - transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( - new_flash_attn_forward - ) - transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( - new_decoder_forward - ) diff --git a/src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py b/src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py deleted file mode 100644 index b05716d0c..000000000 --- a/src/axolotl/integrations/easy_context/ulysses_attn/prepare_inputs.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch - - -def extract_local(value, rank, world_size, device, dim=1): - dimension_size = value.shape[dim] - sub_seq_length = dimension_size // world_size - - sub_seq_start = rank * sub_seq_length - sub_seq_end = (rank + 1) * sub_seq_length - local_value = value[:, sub_seq_start:sub_seq_end] - - return local_value.to(device) - - -def prepare_ulysses_attn_inputs( - input_ids, position_ids, target_ids, rank, world_size, device -): - local_input_ids = extract_local( - input_ids, - rank, - world_size, - device, - ) - local_position_ids = extract_local( - position_ids, - rank, - world_size, - device, - ) - - if target_ids is not None: - local_target_ids = extract_local( - target_ids, - rank, - world_size, - device, - ) - else: - local_target_ids = None - return { - "local_input_ids": local_input_ids, - "local_position_ids": local_position_ids, - "local_target_ids": local_target_ids, - } diff --git a/src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py b/src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py deleted file mode 100644 index 5d5d6a5a7..000000000 --- a/src/axolotl/integrations/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
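
The `apply_*_monkey_patch_llama` functions deleted above all use the same mechanism: plain attribute assignment on the `transformers` classes, so every already-instantiated layer picks up the replacement the next time it is called. A hedged sketch of that pattern, assuming a `transformers` version that still exposes `LlamaFlashAttention2` as the removed code did (`apply_patch` is an illustrative name, not part of the removed module):

```
import transformers


def apply_patch(new_attn_forward, new_decoder_forward):
    # Sketch of the monkey-patch mechanism used by the removed integration:
    # reassign the methods on the transformers classes and keep the originals
    # so the patch could be undone (the removed code never restored them).
    llama = transformers.models.llama.modeling_llama
    originals = (
        llama.LlamaFlashAttention2._flash_attention_forward,
        llama.LlamaDecoderLayer.forward,
    )
    llama.LlamaFlashAttention2._flash_attention_forward = new_attn_forward
    llama.LlamaDecoderLayer.forward = new_decoder_forward
    return originals
```
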
- -import inspect - -import torch -import transformers - - -class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): - """ - Saves VRAM by smartly offloading to RAM. - Tiny hit to performance, since we mask the movement via non blocking calls. - """ - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, forward_function, hidden_states, *args): - saved_hidden_states = hidden_states.to("cpu", non_blocking=True) - with torch.no_grad(): - output = forward_function(hidden_states, *args) - ctx.save_for_backward(saved_hidden_states) - ctx.forward_function = forward_function - ctx.args = args - - return output - - pass - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, dY): - (hidden_states,) = ctx.saved_tensors - hidden_states = hidden_states.to("cuda", non_blocking=True).detach() - hidden_states.requires_grad = True - with torch.enable_grad(): - (output,) = ctx.forward_function(hidden_states, *ctx.args) - torch.autograd.backward(output, dY) - return ( - None, - hidden_states.grad, - ) + ( - None, - ) * len(ctx.args) - - pass - - -pass - - -def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - assert gradient_checkpointing_kwargs == None - if not self.supports_gradient_checkpointing: - raise ValueError( - f"{self.__class__.__name__} does not support gradient checkpointing." - ) - - gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` method - _is_using_old_format = ( - "value" in inspect.signature(self._set_gradient_checkpointing).parameters - ) - - if not _is_using_old_format: - self._set_gradient_checkpointing( - enable=True, gradient_checkpointing_func=gradient_checkpointing_func - ) - else: - raise NotImplementedError() - - if getattr(self, "_hf_peft_config_loaded", False): - # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True - # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 - # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate - # the gradients to make sure the gradient flows. - self.enable_input_require_grads() - - -def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(): - transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = ( - new_gradient_checkpointing_enable - ) diff --git a/src/axolotl/integrations/easy_context/usp/monkey_patch.py b/src/axolotl/integrations/easy_context/usp/monkey_patch.py deleted file mode 100644 index 9745036f3..000000000 --- a/src/axolotl/integrations/easy_context/usp/monkey_patch.py +++ /dev/null @@ -1,114 +0,0 @@ -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -import transformers - -try: - from yunchang import LongContextAttention, set_seq_parallel_pg - - usp_attn = LongContextAttention(ring_impl_type="zigzag") -except: - print( - "If you want to use the LongContextAttention class, please install the yunchang package." 
- ) - usp_attn = None - - -def new_flash_attn_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, -): - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - assert attention_mask is None - assert causal is True - assert use_sliding_windows is False - attn_output = usp_attn( - query_states, - key_states, - value_states, - dropout, - softmax_scale, - causal=causal, - ) - - return attn_output - - -def new_decoder_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - assert isinstance( - self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 - ) or isinstance( - self.self_attn, - transformers.models.mistral.modeling_mistral.MistralFlashAttention2, - ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." - - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -def apply_usp_attn_monkey_patch_llama(): - transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( - new_flash_attn_forward - ) - transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( - new_decoder_forward - ) diff --git a/src/axolotl/integrations/easy_context/usp/prepare_inputs.py b/src/axolotl/integrations/easy_context/usp/prepare_inputs.py deleted file mode 100644 index c2c4f2fad..000000000 --- a/src/axolotl/integrations/easy_context/usp/prepare_inputs.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from yunchang import set_seq_parallel_pg -from yunchang.comm import zigzag_extract_local - - -def prepare_usp_attn_inputs( - input_ids, - position_ids, - target_ids, - rank, - world_size, - device, - ring_degree, - ulysses_degree, -): - f""" - prepare input for USP attention - - USP: A Unified Sequence Parallelism Approach for Long Context Generative AI - https://arxiv.org/abs/2405.07719 - """ - - set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size) - - local_input_ids = zigzag_extract_local( - input_ids, - rank, - world_size, - ring_degree, - ulysses_degree, - 
).to(device) - - # truncate position_ids to the same size as input_ids - position_ids = position_ids[:, : local_input_ids.shape[1]] - - local_position_ids = zigzag_extract_local( - position_ids, - rank, - world_size, - ring_degree, - ulysses_degree, - ).to(device) - - if target_ids is not None: - local_target_ids = zigzag_extract_local( - target_ids, - rank, - world_size, - ring_degree, - ulysses_degree, - ).to(device) - else: - local_target_ids = None - return { - "local_input_ids": local_input_ids, - "local_position_ids": local_position_ids, - "local_target_ids": local_target_ids, - } diff --git a/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py deleted file mode 100644 index ddf7dd292..000000000 --- a/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.utils.checkpoint -from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer -from transformers.models.mistral.modeling_mistral import ( - MistralAttention, - MistralDecoderLayer, -) - - -def new_flash_attn_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, -): - assert ( - self.config._attn_implementation == "flash_attention_2" - ), "Only Flash Attention is supported." - - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - assert attention_mask is None - assert causal is True - assert use_sliding_windows is False - attn_output = zigzag_ring_flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale, - causal=causal, - ) - - return attn_output - - -def new_decoder_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - assert isinstance(self.self_attn, LlamaAttention) or isinstance( - self.self_attn, - MistralAttention, - ), "Llama and Mistral attention only are supported." 
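
Both the USP path deleted above (via yunchang's `zigzag_extract_local`) and the plain zigzag ring-attention `prepare_inputs.py` deleted below shard the sequence in a zigzag pattern rather than contiguously: rank `r` keeps chunk `r` and chunk `2 * world_size - 1 - r`, pairing early and late tokens so the causal-attention work is roughly balanced across ranks. A minimal sketch of the plain ring case (yunchang's real helper additionally accounts for the Ulysses degree):

```
import torch


def zigzag_extract_local(value: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # Split into 2 * world_size chunks; rank r keeps chunks r and (2W - 1 - r),
    # mirroring the extract_local in the removed zigzag_ring_attn/prepare_inputs.py.
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=dim)


# Toy illustration: seq_len=8, world_size=2.
tokens = torch.arange(8).unsqueeze(0)
# rank 0 -> chunks 0 and 3 -> tensor([[0, 1, 6, 7]])
# rank 1 -> chunks 1 and 2 -> tensor([[2, 3, 4, 5]])
```
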
- - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -def apply_zigzag_ring_attn_monkey_patch_llama(): - # LlamaAttention._flash_attention_forward = new_flash_attn_forward - ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward}) - LlamaDecoderLayer.forward = new_decoder_forward - - -def apply_zigzag_ring_attn_monkey_patch_mistral(): - # MistralAttention._flash_attention_forward = new_flash_attn_forward - ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward}) - MistralDecoderLayer.forward = new_decoder_forward diff --git a/src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py b/src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py deleted file mode 100644 index f5cea9064..000000000 --- a/src/axolotl/integrations/easy_context/zigzag_ring_attn/prepare_inputs.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch - - -def extract_local(value, rank, world_size, device, dim=1): - value_chunks = value.chunk(2 * world_size, dim=dim) - local_value = torch.cat( - [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim - ) - return local_value.to(device) - - -def prepare_zigzag_ring_attn_inputs( - input_ids, position_ids, target_ids, rank, world_size, device -): - local_input_ids = extract_local( - input_ids, - rank, - world_size, - device, - ) - local_position_ids = extract_local( - position_ids, - rank, - world_size, - device, - ) - if target_ids is not None: - local_target_ids = extract_local( - target_ids, - rank, - world_size, - device, - ) - else: - local_target_ids = None - return { - "local_input_ids": local_input_ids, - "local_position_ids": local_position_ids, - "local_target_ids": local_target_ids, - } diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 461324f32..96ecdb4b1 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -549,16 +549,6 @@ class ModelLoader: patch_self_attn_lora(self.cfg) if self.cfg.sequence_parallel_size > 1: - # from axolotl.integrations.easy_context import ( - # apply_seq_parallel_monkey_patch, - # ) - - # method = self.cfg.sequence_parallel_mode - # model_type = self.cfg.model_config_type - - # # Apply the monkey patch - # apply_seq_parallel_monkey_patch(method, model_type) - register_ring_attn(self.cfg.sequence_parallel_size) def patch_attention(self) -> None: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 501bf3f68..ca0d79a27 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -17,7 +17,6 @@ from torch.utils.data import DataLoader, RandomSampler from transformers.utils import is_torch_bf16_gpu_available from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder -from axolotl.integrations.easy_context import 
prepare_seq_parallel_inputs from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -357,17 +356,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): **filter_map_kwargs, **drop_long_kwargs, ) - if cfg.sequence_parallel_size > 1: - train_dataset.map( - prepare_seq_parallel_inputs, - "dist_flash_attn", - lambda batch: batch["input_ids"], - lambda batch: batch["position_ids"], - lambda batch: batch["target_ids"], - accelerator.process_index, - accelerator.num_processes, - accelerator.device, - ) if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1: if eval_dataset: eval_dataset = eval_dataset.map(