Adding easy_context as an integration for now

This commit is contained in:
Dan Saunders
2025-03-03 19:59:10 +00:00
parent 113e9cd193
commit 3f8a43cab6
17 changed files with 4874 additions and 0 deletions

View File

@@ -0,0 +1,96 @@
import logging
from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs
from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs
from .usp.monkey_patch import apply_usp_attn_monkey_patch_llama
from .usp.prepare_inputs import prepare_usp_attn_inputs
from .zigzag_ring_attn.monkey_patch import (
apply_zigzag_ring_attn_monkey_patch_llama,
apply_zigzag_ring_attn_monkey_patch_mistral,
)
from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs
logger = logging.getLogger(__name__)
def prepare_seq_parallel_inputs(
seq_algo,
input_ids,
position_ids,
target_ids,
rank,
world_size,
device,
*args,
**kwargs,
):
if seq_algo == "zigzag_ring_attn":
return prepare_zigzag_ring_attn_inputs(
input_ids, position_ids, target_ids, rank, world_size, device
)
elif seq_algo == "dist_flash_attn":
return prepare_dist_flash_attn_inputs(
input_ids, position_ids, target_ids, rank, world_size, device
)
elif seq_algo == "ulysses_attn":
return prepare_ulysses_attn_inputs(
input_ids, position_ids, target_ids, rank, world_size, device
)
elif seq_algo == "usp_attn":
ring_degree = kwargs.get("ring_degree", 1)
ulysses_degree = world_size // ring_degree
logger.info(
f"Applying USP: Ring degree: {ring_degree}, Ulysses degree: {ulysses_degree}"
)
        # prepare_usp_attn_inputs expects (..., ring_degree, ulysses_degree) in this order
        return prepare_usp_attn_inputs(
            input_ids,
            position_ids,
            target_ids,
            rank,
            world_size,
            device,
            ring_degree,
            ulysses_degree,
        )
elif seq_algo == "data_parallel":
return {
"local_input_ids": input_ids.to(device),
"local_position_ids": position_ids.to(device),
"local_target_ids": target_ids.to(device),
}
else:
raise ValueError(f"Invalid seq_algo: {seq_algo}")
def apply_seq_parallel_monkey_patch(seq_algo, model):
assert seq_algo in [
"zigzag_ring_attn",
"dist_flash_attn",
"ulysses_attn",
"data_parallel",
"usp_attn",
], f"Invalid seq_algo: {seq_algo}"
assert model in ["llama", "mistral"], f"Invalid model: {model}"
if seq_algo == "data_parallel":
return
elif seq_algo == "zigzag_ring_attn" and model == "llama":
apply_zigzag_ring_attn_monkey_patch_llama()
elif seq_algo == "zigzag_ring_attn" and model == "mistral":
apply_zigzag_ring_attn_monkey_patch_mistral()
elif seq_algo == "dist_flash_attn" and model == "llama":
apply_dist_flash_attn_monkey_patch_llama()
elif seq_algo == "ulysses_attn" and model == "llama":
apply_ulysses_attn_monkey_patch_llama()
elif seq_algo == "usp_attn" and model == "llama":
apply_usp_attn_monkey_patch_llama()
else:
raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}")
def prepare_dataloader(seq_algo, dataloader, accelerator):
if seq_algo == "data_parallel":
return accelerator.prepare(dataloader)
else:
return dataloader
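
A minimal sketch of how these helpers are typically wired together. The `easy_context` import path, the model name, and the `train_loader` are placeholders for illustration, not part of this commit:

```
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Assumed import path for the package added in this commit.
from easy_context import (
    apply_seq_parallel_monkey_patch,
    prepare_dataloader,
    prepare_seq_parallel_inputs,
)

seq_algo = "zigzag_ring_attn"
apply_seq_parallel_monkey_patch(seq_algo, "llama")  # patch before running the model

accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model = accelerator.prepare(model)

train_loader = ...  # assumed: a DataLoader yielding full-length input_ids/position_ids/target_ids
train_loader = prepare_dataloader(seq_algo, train_loader, accelerator)

for batch in train_loader:
    # Shard the full sequence across ranks for the chosen algorithm.
    prepared = prepare_seq_parallel_inputs(
        seq_algo,
        batch["input_ids"],
        batch["position_ids"],
        batch["target_ids"],
        accelerator.process_index,
        accelerator.num_processes,
        accelerator.device,
    )
    outputs = model(
        prepared["local_input_ids"],
        position_ids=prepared["local_position_ids"],
    )
```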

View File

@@ -0,0 +1,11 @@
# LightSeq
Taken from https://github.com/RulinShao/LightSeq. All credits to the authors.
```
@article{li2023lightseq,
  title={LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers},
  author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P and Gonzalez, Joseph E and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
journal={arXiv preprint arXiv:2310.03294},
year={2023}
}
```

View File

@@ -0,0 +1,776 @@
import math
import os
import threading
import torch
import torch.distributed as dist
from torch.distributed import P2POp, batch_isend_irecv, irecv, isend
# Sequence parallel group that the current rank belongs to.
_SEQUENCE_PARALLEL_GROUP = None
# These values enable us to change the sequence parallel sizes on the fly.
_SEQUENCE_PARALLEL_SIZE = None
_SEQUENCE_PARALLEL_RANK = None
# Global buffer for P2P
_PEER_Q = None
_PEER_K = None
_PEER_V = None
_PEER_M = None
_PEER_L = None
_PEER_O = None
_PEER_Q_BWD = None
_PEER_K_BWD = None
_PEER_V_BWD = None
_PEER_O_BWD = None
_DELTA_DQ = None
_PEER_L = None
_DELTA_DK = None
_DELTA_DV = None
_DK_DELTA_FROM_PEER = None
_DV_DELTA_FROM_PEER = None
_PEER_DO = None
_fwd_send_volume = 0
_fwd_recv_volume = 0
_bwd_send_volume = 0
_bwd_recv_volume = 0
def initialize_distributed():
if dist.is_initialized():
if dist.get_rank() == 0:
print(
"torch distributed is already initialized, "
"skipping initialization ...",
flush=True,
)
else:
if int(os.environ["RANK"]) == 0:
print("Initializing Torch distributed.")
dist.init_process_group(backend="nccl")
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
global_world_size = dist.get_world_size()
torch.cuda.set_device(dist.get_rank() % local_world_size)
_initialize_sequence_parallel()
# create_nccl_communicators()
def _initialize_sequence_parallel(sequence_parallel_size=None):
# Get world size and rank. Ensure some consistencies.
    assert (
        sequence_parallel_size is None
    ), "Multiple sequence parallel groups are not implemented."
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
if sequence_parallel_size is None:
sequence_parallel_size = world_size
else:
assert world_size % sequence_parallel_size == 0
num_sequence_parallel_groups: int = world_size // sequence_parallel_size
rank = torch.distributed.get_rank()
# Build the sequence parallel groups.
global _SEQUENCE_PARALLEL_GROUP
global _SEQUENCE_PARALLEL_RANK
global _SEQUENCE_PARALLEL_SIZE
assert (
_SEQUENCE_PARALLEL_GROUP is None
), "sequence parallel group is already initialized"
for i in range(num_sequence_parallel_groups):
ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_SEQUENCE_PARALLEL_GROUP = group
_SEQUENCE_PARALLEL_RANK = ranks.index(rank)
_SEQUENCE_PARALLEL_SIZE = len(ranks)
if dist.get_rank() == 0:
print("************ Finish sequence pralell group Initialization. ***********")
# _set_global_memory_buffer()
def maybe_get_set_global_memory_buffer(q, k, v, m, l, o):
global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
if _PEER_Q is None:
        try:
            if get_sequence_parallel_rank() == 0:
                print("Initializing global memory buffer.")
        except Exception:
            print("Initializing global memory buffer.")
_PEER_Q = [torch.empty_like(q) for _ in range(2)]
_PEER_K = [torch.empty_like(k) for _ in range(2)]
_PEER_V = [torch.empty_like(v) for _ in range(2)]
_PEER_M = [torch.empty_like(m) for _ in range(2)]
_PEER_L = [torch.empty_like(l) for _ in range(2)]
_PEER_O = [torch.empty_like(o) for _ in range(2)]
return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do):
global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO
if _DELTA_DQ is None:
        try:
            if get_sequence_parallel_rank() == 0:
                print("Initializing global memory buffer for backward.")
        except Exception:
            print("Initializing global memory buffer for backward.")
_DELTA_DQ = [torch.empty_like(dq) for _ in range(2)]
_DELTA_DK = [torch.empty_like(dk) for _ in range(2)]
_DELTA_DV = [torch.empty_like(dv) for _ in range(2)]
_PEER_L = [torch.empty_like(L) for _ in range(2)]
_DK_DELTA_FROM_PEER = torch.empty_like(dk)
_DV_DELTA_FROM_PEER = torch.empty_like(dv)
        # These may already be initialized in the forward call, but the forward and
        # backward passes use different (transposed) layouts for q, so allocate separately.
_PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)]
_PEER_K_BWD = [torch.empty_like(k) for _ in range(2)]
_PEER_V_BWD = [torch.empty_like(v) for _ in range(2)]
_PEER_O_BWD = [torch.empty_like(o) for _ in range(2)]
_PEER_DO = [torch.empty_like(do) for _ in range(2)]
return (
_DELTA_DQ,
_DELTA_DK,
_DELTA_DV,
_DK_DELTA_FROM_PEER,
_DV_DELTA_FROM_PEER,
_PEER_Q_BWD,
_PEER_L,
_PEER_K_BWD,
_PEER_V_BWD,
_PEER_O_BWD,
_PEER_DO,
)
def reset_global_memory_buffer():
global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO
_PEER_Q = None
_PEER_K = None
_PEER_V = None
_PEER_M = None
_PEER_L = None
_PEER_O = None
_DELTA_DQ = None
_PEER_L = None
_DELTA_DK = None
_DELTA_DV = None
_DK_DELTA_FROM_PEER = None
_DV_DELTA_FROM_PEER = None
_PEER_DO = None
# Pytorch defers the creation of nccl communicators to the first P2P call,
# We manually create them so the first isend does not hang without an irecv.
# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138
# Only support even number of GPUs.
def create_nccl_communicators():
seq_rank = get_sequence_parallel_rank()
seq_group = get_sequence_parallel_group()
empty_tensor = torch.empty(
1,
).cuda()
empty_tensor_2 = torch.empty(
1,
).cuda()
if torch.distributed.get_rank() % 2 == 0:
# sender
op1 = P2POp(
op=isend,
tensor=torch.empty(
1,
).cuda(),
peer=seq_rank + 1,
group=seq_group,
)
op2 = P2POp(
op=irecv,
tensor=torch.empty(
1,
).cuda(),
peer=seq_rank + 1,
group=seq_group,
)
# req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group)
dist.batch_isend_irecv([op1, op2])
else:
# receiver
op1 = P2POp(
op=irecv,
tensor=torch.empty(
1,
).cuda(),
peer=seq_rank - 1,
group=seq_group,
)
op2 = P2POp(
op=isend,
tensor=torch.empty(
1,
).cuda(),
peer=seq_rank - 1,
group=seq_group,
)
# req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group)
handles = dist.batch_isend_irecv([op1, op2])
# req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group)
dist.all_reduce(empty_tensor, group=seq_group)
def get_sequence_parallel_group():
"""Get the sequence parallel group the caller rank belongs to."""
# global _SEQUENCE_PARALLEL_GROUP
assert (
_SEQUENCE_PARALLEL_GROUP is not None
), "sequence parallel group is not initialized"
return _SEQUENCE_PARALLEL_GROUP
def get_sequence_parallel_rank():
"""Return my rank for the sequence parallel group."""
global _SEQUENCE_PARALLEL_RANK
if _SEQUENCE_PARALLEL_RANK is not None:
return _SEQUENCE_PARALLEL_RANK
return torch.distributed.get_rank(group=get_sequence_parallel_group())
def get_sequence_parallel_size():
"""Return my rank for the sequence parallel group."""
global _SEQUENCE_PARALLEL_SIZE
if _SEQUENCE_PARALLEL_SIZE is not None:
return _SEQUENCE_PARALLEL_SIZE
return torch.distributed.get_world_size(group=get_sequence_parallel_group())
def destroy_sequence_parallel():
"""Set the groups to none."""
global _SEQUENCE_PARALLEL_GROUP
_SEQUENCE_PARALLEL_GROUP = None
# Whether this is the last time step at which this rank's kernel is called
def is_last_time(time_step):
    # e.g. on an 8-GPU setup, the last time step for each rank is:
    #   R=0 -> t=0, R=1 -> t=1, R=2 -> t=2, R=3 -> t=3, R=4..7 -> t=4
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
if seq_rank <= seq_world_size // 2: # no one helps these ranks
rank_finish_time = seq_rank
else:
rank_finish_time = seq_world_size // 2
return rank_finish_time == time_step
# Whether the current time step computes attention for the local query
def is_compute_for_local_query(time_step):
    # e.g. on an 8-GPU setup (within each rank's active steps):
    #   R=0: only t=0;  R=1: t=0-1;  R=2: t=0-2;  R=3..7: every step
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
if seq_rank >= min(seq_world_size // 2, time_step):
return True
return False
# Whether the current time step is idle for this rank
def is_idle(time_step):
    # e.g. on an 8-GPU setup: ranks 0-3 are idle at t=4; ranks 4-7 are never idle
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2:
return True
return False
# Whether the current time step needs to synchronize with a remotely computed result
def is_sync_from_remote(time_step):
    # e.g. on an 8-GPU setup: ranks 0-4 never sync; R=5 at t=4; R=6 at t=3-4; R=7 at t=2-4
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
if seq_rank > max(seq_world_size // 2, seq_world_size - time_step):
return True
return False
def maybe_send_recv_fwd_qkvo(
q: torch.Tensor,
peer_q: torch.Tensor,
k: torch.Tensor,
peer_k: torch.Tensor,
v: torch.Tensor,
peer_v: torch.Tensor,
o_stats: list, # peer_o_stats: list,
time_step: int,
comm_mode,
debug=False,
) -> torch.Tensor:
seq_group = get_sequence_parallel_group()
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
    # Handles for operations that actually need to be waited on before going to
    # the next iteration. For instance, the QKV sender never needs to wait ->
    # it seems fusing these calls helps the scheduler.
    all_handles = []
    # KV logic: unlike the older version, every rank sends/recvs its own kv to
    # balance communication. In a balanced schedule, each rank should send/recv
    # 4 tensors in total per step (kv, or qo). For instance, rank 0 at
    # time step > 0 should send its own kv and send/recv qo. In the older version,
    # rank 0 did not send its kv and relied on a later rank to pass it along, so the
    # later rank had to (1) receive kv, (2) send rank 0's kv, and (3) send/recv qo.
    # Q (load balancing) logic: semantically this is "%" world size, i.e. the same
    # send/recv rank as KV. Note: only an even number of machines is supported.
    # O (load balancing) logic: rank 0 sends its result to rank 7 at time 1.
    # It is delayed by one time step and thus has a different maybe_send/recv rank.
    # Use (time_step + 1) to easily convert to the synchronous version.
maybe_send_rank = seq_rank + (time_step + 1)
maybe_recv_rank = seq_rank - (time_step + 1)
if debug:
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
_debug_send = _fwd_send_volume
_debug_recv = _fwd_recv_volume
if maybe_send_rank >= seq_world_size:
# send q, no one needs to do remote computation in the last time step
if time_step < (seq_world_size // 2 - 1):
# print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)")
# q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group))
all_handles.append(
P2POp(
op=isend,
tensor=q,
peer=maybe_send_rank % seq_world_size,
group=seq_group,
)
)
if debug:
_fwd_send_volume += torch.numel(q) * q.element_size()
else:
# send kv
# print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)")
# kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group))
# kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group))
all_handles.append(
P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)
)
all_handles.append(
P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)
)
if debug:
_fwd_send_volume += torch.numel(k) * k.element_size()
_fwd_send_volume += torch.numel(v) * v.element_size()
if maybe_recv_rank < 0:
# recv q, no one needs to do remote computation in the last time step
if time_step < (seq_world_size // 2 - 1):
# print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)")
# q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group))
all_handles.append(
P2POp(
op=irecv,
tensor=peer_q,
peer=maybe_recv_rank % seq_world_size,
group=seq_group,
)
)
if debug:
_fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size()
else:
# recv kv
# print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)")
# kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group))
# kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group))
all_handles.append(
P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)
)
all_handles.append(
P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)
)
if debug:
_fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size()
_fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size()
maybe_send_rank_o = seq_rank - (time_step - 1)
maybe_recv_rank_o = seq_rank + (time_step - 1)
if maybe_send_rank_o < 0 and time_step > 1:
for t in o_stats:
# print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)")
# o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group))
all_handles.append(
P2POp(
op=isend,
tensor=t,
peer=maybe_send_rank_o % seq_world_size,
group=seq_group,
)
)
if debug:
_fwd_send_volume += torch.numel(t) * t.element_size()
if maybe_recv_rank_o >= seq_world_size and time_step > 1:
for t in o_stats:
# print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)")
# o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group))
all_handles.append(
P2POp(
op=irecv,
tensor=t,
peer=maybe_recv_rank_o % seq_world_size,
group=seq_group,
)
)
if debug:
_fwd_recv_volume += torch.numel(t) * t.element_size()
# reqs = []
if debug:
if seq_rank in [0, 8]:
print(
f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB"
)
# return reqs
all_reqs = launch_async_handles(all_handles, comm_mode)
return [all_reqs]
# delta: it may be used for your local compute or as a distributed buffer to send to others.
# (Apologies for the naming.)
def maybe_send_recv_bwd_qkvo(
dq_delta: torch.Tensor,
dk_delta: torch.Tensor,
dv_delta: torch.Tensor,
dk_delta_from_peer: torch.Tensor,
dv_delta_from_peer: torch.Tensor,
q: torch.Tensor,
peer_q: torch.Tensor,
L: torch.Tensor,
peer_L: torch.Tensor,
k: torch.Tensor,
peer_k: torch.Tensor,
v: torch.Tensor,
peer_v: torch.Tensor,
o: torch.Tensor,
peer_o: torch.Tensor,
do: torch.Tensor,
peer_do: torch.Tensor,
time_step: int,
comm_mode,
debug=False,
):
seq_group = get_sequence_parallel_group()
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
all_handles = []
maybe_send_rank = seq_rank + (time_step + 1)
maybe_recv_rank = seq_rank - (time_step + 1)
if debug:
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
if maybe_send_rank >= seq_world_size:
# send q, no one needs to do remote computation in the last time step
if time_step < (seq_world_size // 2 - 1):
all_handles.append(
P2POp(
op=isend,
tensor=q,
peer=maybe_send_rank % seq_world_size,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=isend,
tensor=L,
peer=maybe_send_rank % seq_world_size,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=isend,
tensor=o,
peer=maybe_send_rank % seq_world_size,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=isend,
tensor=do,
peer=maybe_send_rank % seq_world_size,
group=seq_group,
)
)
if debug:
_bwd_send_volume += torch.numel(q) * q.element_size()
_bwd_send_volume += torch.numel(L) * L.element_size()
_bwd_send_volume += torch.numel(o) * o.element_size()
_bwd_send_volume += torch.numel(do) * do.element_size()
else:
# send kv
all_handles.append(
P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)
)
all_handles.append(
P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)
)
if debug:
_bwd_send_volume += torch.numel(k) * k.element_size()
_bwd_send_volume += torch.numel(v) * v.element_size()
if maybe_recv_rank < 0:
# recv q, no one needs to do remote computation in the last time step
if time_step < (seq_world_size // 2 - 1):
all_handles.append(
P2POp(
op=irecv,
tensor=peer_q,
peer=maybe_recv_rank % seq_world_size,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=irecv,
tensor=peer_L,
peer=maybe_recv_rank % seq_world_size,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=irecv,
tensor=peer_o,
peer=maybe_recv_rank % seq_world_size,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=irecv,
tensor=peer_do,
peer=maybe_recv_rank % seq_world_size,
group=seq_group,
)
)
if debug:
_bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size()
_bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size()
_bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size()
_bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size()
else:
# recv kv
all_handles.append(
P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)
)
all_handles.append(
P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)
)
if debug:
_bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size()
_bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size()
    # Whether dq, dk and dv should be updated after waiting on these requests
is_update_dq = False
is_update_dkv = False
maybe_send_rank_dqkv = seq_rank - (time_step - 1)
maybe_recv_rank_dqkv = seq_rank + (time_step - 1)
if time_step > 1:
if maybe_send_rank_dqkv < 0:
# print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}")
all_handles.append(
P2POp(
op=isend,
tensor=dq_delta,
peer=maybe_send_rank_dqkv % seq_world_size,
group=seq_group,
)
)
if debug:
_bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size()
else:
# print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}")
all_handles.append(
P2POp(
op=isend,
tensor=dk_delta,
peer=maybe_send_rank_dqkv,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=isend,
tensor=dv_delta,
peer=maybe_send_rank_dqkv,
group=seq_group,
)
)
if debug:
_bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size()
_bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size()
if maybe_recv_rank_dqkv >= seq_world_size:
# print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}")
all_handles.append(
P2POp(
op=irecv,
tensor=dq_delta,
peer=maybe_recv_rank_dqkv % seq_world_size,
group=seq_group,
)
)
is_update_dq = True
if debug:
_bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size()
else:
# print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}")
all_handles.append(
P2POp(
op=irecv,
tensor=dk_delta_from_peer,
peer=maybe_recv_rank_dqkv,
group=seq_group,
)
)
all_handles.append(
P2POp(
op=irecv,
tensor=dv_delta_from_peer,
peer=maybe_recv_rank_dqkv,
group=seq_group,
)
)
is_update_dkv = True
if debug:
_bwd_recv_volume += (
torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size()
)
_bwd_recv_volume += (
torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size()
)
# return [], is_update_dq, is_update_dkv
all_reqs = launch_async_handles(all_handles, comm_mode)
return [all_reqs], is_update_dq, is_update_dkv
def maybe_send_recv_bwd_last_dkv(
dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False
):
is_update_last_dkv = False
seq_group = get_sequence_parallel_group()
seq_rank = get_sequence_parallel_rank()
seq_world_size = get_sequence_parallel_size()
if seq_world_size == 1:
return [], is_update_last_dkv
all_handles = []
if debug:
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
if time_step == seq_world_size // 2:
maybe_send_rank = seq_rank - time_step
maybe_recv_rank = seq_rank + time_step
        assert (maybe_send_rank >= 0) ^ (
            maybe_recv_rank < seq_world_size
        ), f"R={seq_rank} should be either sending or receiving dkv in the last time step."
if maybe_send_rank >= 0:
# print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}")
all_handles.append(
P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group)
)
all_handles.append(
P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group)
)
if debug:
_bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size()
_bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size()
if maybe_recv_rank < seq_world_size:
# print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}")
all_handles.append(
P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group)
)
all_handles.append(
P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group)
)
if debug:
_bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size()
_bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size()
is_update_last_dkv = True
# return [], is_update_last_dkv
all_reqs = launch_async_handles(all_handles, comm_mode)
return [all_reqs], is_update_last_dkv
def print_and_reset_comm_stats():
seq_rank = get_sequence_parallel_rank()
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
_fwd_send_volume *= 1e-9
_fwd_recv_volume *= 1e-9
_bwd_send_volume *= 1e-9
_bwd_recv_volume *= 1e-9
print(
f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB."
)
_fwd_send_volume = 0
_fwd_recv_volume = 0
_bwd_send_volume = 0
_bwd_recv_volume = 0
def launch_async_handles(handles, comm_mode):
    if comm_mode == "nocomm":
        # print("skipping communication for ablation")
        return []
if len(handles) > 0:
return dist.batch_isend_irecv(handles)
return []
def wait_async_handles(reqs):
if len(reqs) > 0:
for req in reqs:
for r in req:
r.wait()
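
The schedule predicates above are easiest to read off a concrete world size. A standalone illustration (mirroring `is_last_time`, `is_compute_for_local_query`, `is_idle`, and `is_sync_from_remote` with explicit arguments instead of the distributed state):

```
# Illustration only: the real functions read rank/size from the sequence-parallel group.
def lightseq_schedule(world_size=8):
    half = world_size // 2
    for rank in range(world_size):
        finish = rank if rank <= half else half
        local_q = [t for t in range(half + 1) if rank >= min(half, t)]
        idle = [half] if rank < half else []
        sync = [t for t in range(half + 1) if rank > max(half, world_size - t)]
        print(f"R={rank}: finishes at t={finish}, local-q at {local_q}, "
              f"idle at {idle}, sync-from-remote at {sync}")

lightseq_schedule()
# e.g. R=0 finishes at t=0; R=7 finishes at t=4 and syncs remote results at t=2,3,4.
```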

File diff suppressed because it is too large

View File

@@ -0,0 +1,755 @@
"""
Materialization-aware gradient checkpointing monkey patch.
"""
from typing import List, Optional, Tuple
import torch
import transformers
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import (
_get_autocast_kwargs,
check_backward_validity,
detach_variable,
get_device_states,
set_device_states,
)
from transformers.models.llama.modeling_llama import (
BaseModelOutputWithPast,
apply_rotary_pos_emb,
)
from .async_communication import initialize_distributed, reset_global_memory_buffer
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
# define a global buffer to save flash attention outputs
# it's called global because it saves the outputs for all layers
global_flash_attn_out_buffer = None
# define a local buffer to save recomputed qkv
# it's called local because it's a temporary buffer which will be updated across layers
local_res_grad_buffer = None
# hooks for the gradients of residual
global_hooks = []
def init_flash_attn_buffers(num_layers):
# update the global buffer according to number of layers
global global_flash_attn_out_buffer
global_flash_attn_out_buffer = [None] * num_layers
def clean_hook():
# Remove all hooks in the global buffer
for hook in global_hooks:
hook.remove()
# Clear the global buffer
global_hooks.clear()
def clear_all_buffers_at_the_end_of_training():
# call it at the end of training
    global global_flash_attn_out_buffer
global_flash_attn_out_buffer = None
global local_res_grad_buffer
local_res_grad_buffer = None
clean_hook()
def save_flash_attn_out_to_global_buffer(idx, out):
global global_flash_attn_out_buffer
global_flash_attn_out_buffer[idx] = out
def get_flash_attn_out_from_global_buffer(idx):
global global_flash_attn_out_buffer
return global_flash_attn_out_buffer[idx]
def free_flash_attn_out_buffer(idx):
global global_flash_attn_out_buffer
global_flash_attn_out_buffer[idx] = None
def write_gradient_to_flash_attn_out(idx, grad):
global global_flash_attn_out_buffer
global_flash_attn_out_buffer[idx].grad = grad
def save_res_grad_hook(grad):
global local_res_grad_buffer
local_res_grad_buffer = grad
def load_and_add_res_grad_hook(grad):
grad += get_res_grad_from_local_buffer()
def get_res_grad_from_local_buffer():
global local_res_grad_buffer
assert local_res_grad_buffer is not None
return local_res_grad_buffer
class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function):
"""Avoid doing twice flash attention forward during checkpointed backward.
args:
hidden_states, # i.e., flash attention output which is saved in global buffer.
attention_mask,
position_ids,
residual, # the gradient of residual is saved in local buffer to pass across ckpt layers.
"""
@staticmethod
def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.layer_idx = layer_idx
ctx.preserve_rng_state = preserve_rng_state
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if i == 0 and ctx.layer_idx != 0:
# flash attention output is saved to the global buffer during forward
ctx.inputs.append(None)
else:
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
with torch.no_grad():
q, k, v, residual = run_function(*args)
softmax_scale = q.shape[-1] ** (-0.5)
# lightseq version
_, _, _, out, softmax_lse = _lightseq_forward(
q, k, v, True, softmax_scale, comm_mode="lightseq"
)
rng_state = None
# save flash attention output to global buffer
save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
tensor_inputs += [softmax_lse]
ctx.softmax_scale = softmax_scale
ctx.save_for_backward(*tensor_inputs)
return out, residual
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
tensors, softmax_lse = tensors[:-1], tensors[-1]
# Fill in inputs with appropriate saved tensors.
# Fill the flash attention output first
if ctx.layer_idx > 0:
# inputs[0] should be flash attention output
inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx - 1)
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
rng_devices = []
if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
rng_devices = ctx.fwd_gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), torch.cuda.amp.autocast(
**ctx.gpu_autocast_kwargs
), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                # Stop recomputation before flash attention:
                # it is unnecessary to recompute flash attention itself.
q, k, v, residual = ctx.run_function(*detached_inputs)
# run backward() with only tensor that requires grad
# run flash attention backward first:
# get 'dout' from auto_grad inputs
# get 'out' from global buffer
# get 'qkv' from the recomputed tensors
# dq = torch.empty(q.shape, dtype=q.dtype, device=q.device)
# dk = torch.empty(k.shape, dtype=q.dtype, device=q.device)
# dv = torch.empty(v.shape, dtype=q.dtype, device=q.device)
out = get_flash_attn_out_from_global_buffer(ctx.layer_idx)
# todo get dout
dout = args[0]
# lightseq version
dq, dk, dv = _lightseq_backward(
dout,
q,
k,
v,
out,
softmax_lse,
ctx.softmax_scale,
comm_mode="lightseq",
backward_engine="flash",
)
# dqkv = torch.stack([dq, dk, dv])
# run backward for the part before flash attention
# qkv.backward(dqkv)
torch.autograd.backward([q, k, v], [dq, dk, dv])
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
# write flash attention output gradients to buffer
if ctx.layer_idx > 0:
write_gradient_to_flash_attn_out(ctx.layer_idx - 1, detached_inputs[0].grad)
return (None, None, None) + grads
def checkpoint_end_with_flash_attention(
function, layer_idx, *args, use_reentrant: bool = True, **kwargs
):
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop("preserve_rng_state", True)
if kwargs and use_reentrant:
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)
return CheckpointFunctionEndWithFlashAttention.apply(
function, layer_idx, preserve, *args
)
class CheckpointFunctionLastModule(torch.autograd.Function):
"""
    For the last FFN layer after flash attention; modifications include
    writing the gradients w.r.t. the flash attention output and the residual to the global buffer.
"""
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
assert torch.is_tensor(
args[0]
), "assuming the first tensor is the flash attention output"
for i, arg in enumerate(args):
if torch.is_tensor(arg) and i == 0:
# flash attn output has been saved to global buffer
ctx.inputs.append(None)
elif torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
# Fill in inputs with appropriate saved tensors.
# Fill the flash attention output first
# inputs[0] should be flash attention output
inputs[0] = get_flash_attn_out_from_global_buffer(-1)
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
rng_devices = []
if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
rng_devices = ctx.fwd_gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), torch.cuda.amp.autocast(
**ctx.gpu_autocast_kwargs
), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# run backward() with only tensor that requires grad
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)):
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True,"
" this checkpoint() is not necessary"
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
# write flash attention output gradients to buffer
write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad)
return (None, None) + grads
def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs):
preserve = kwargs.pop("preserve_rng_state", True)
if kwargs and use_reentrant:
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)
return CheckpointFunctionLastModule.apply(function, preserve, *args)
def llama_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
compute_attn_only: Optional[bool] = False,
compute_ffn_only: Optional[bool] = False,
    residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
assert compute_ffn_only or compute_attn_only
if compute_attn_only:
residual = hidden_states
if residual.requires_grad:
# register a hook to add the gradient of residual
# from next checkpoint layer when doing recomputation
hook = residual.register_hook(load_and_add_res_grad_hook)
global_hooks.append(hook)
hidden_states = self.input_layernorm(hidden_states)
# Flash Attention
bsz, q_len, _ = hidden_states.size()
try:
query_states = (
self.self_attn.q_proj(hidden_states)
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
.transpose(1, 2)
)
key_states = (
self.self_attn.k_proj(hidden_states)
.view(
bsz,
q_len,
self.self_attn.num_key_value_heads,
self.self_attn.head_dim,
)
.transpose(1, 2)
)
value_states = (
self.self_attn.v_proj(hidden_states)
.view(
bsz,
q_len,
self.self_attn.num_key_value_heads,
self.self_attn.head_dim,
)
.transpose(1, 2)
)
        except AttributeError:
            # old transformers versions don't have num_key_value_heads
query_states = (
self.self_attn.q_proj(hidden_states)
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
.transpose(1, 2)
)
key_states = (
self.self_attn.k_proj(hidden_states)
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
.transpose(1, 2)
)
value_states = (
self.self_attn.v_proj(hidden_states)
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.self_attn.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
return (
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
residual,
)
elif compute_ffn_only:
hidden_states = self.self_attn.o_proj(
rearrange(hidden_states, "b h s d -> b s (h d)")
)
# Need to add residual here to make sure checkpoint is right after attention
if residual.requires_grad:
# save the gradient of residual to the local buffer
# collect the hooks which should be removed after backward to avoid memory leak
hook = residual.register_hook(save_res_grad_hook)
global_hooks.append(hook)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
else:
raise AttributeError
return outputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
):
assert cache_position is None, "cache_position is not supported"
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
attention_mask = None
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
try:
logger.warning_once("***** Using fast gradient checkpointing... *****")
except:
pass
# initialize the global buffer
init_flash_attn_buffers(len(self.layers))
if use_cache:
try:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
except:
pass
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# apply flash-attention friendly gradient checkpointing
if self.gradient_checkpointing and self.training:
for idx in range(len(self.layers) + 1):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
def forward_first_attn_module(module):
def custom_forward(*inputs):
hidden_states, attention_mask, position_ids, _ = inputs
# None for past_key_value
return module(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
compute_attn_only=True,
)
return custom_forward
def forward_ffn_attn_layer(module1, module2):
def custom_forward(*inputs):
hidden_states, attention_mask, position_ids, residual = inputs
# None for past_key_value
layer_outputs = module1(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
compute_ffn_only=True,
residual=residual,
)
hidden_states = layer_outputs[0]
return module2(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
compute_attn_only=True,
)
return custom_forward
def forward_last_ffn_module(module):
def custom_forward(*inputs):
hidden_states, attention_mask, position_ids, residual = inputs
# None for past_key_value
return module(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
compute_ffn_only=True,
residual=residual,
)
return custom_forward
if idx == 0:
layer_outputs = checkpoint_end_with_flash_attention(
forward_first_attn_module(self.layers[0]),
idx,
hidden_states,
attention_mask,
position_ids,
None,
)
hidden_states, residual = layer_outputs[0], layer_outputs[-1]
elif idx == len(self.layers):
layer_outputs = checkpoint_last_module(
forward_last_ffn_module(self.layers[-1]),
hidden_states,
attention_mask,
position_ids,
residual,
)
hidden_states = layer_outputs[0]
else:
layer_outputs = checkpoint_end_with_flash_attention(
forward_ffn_attn_layer(self.layers[idx - 1], self.layers[idx]),
idx,
hidden_states,
attention_mask,
position_ids,
residual,
)
hidden_states, residual = layer_outputs[0], layer_outputs[-1]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
else:
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def apply_dist_flash_attn_monkey_patch_llama():
initialize_distributed()
transformers.models.llama.modeling_llama.LlamaModel.forward = forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
llama_layer_forward
)
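
Note that the rewritten `forward` above only takes the materialization-aware checkpointing path when `self.gradient_checkpointing` and `self.training` are both set. A hedged sketch of the expected setup (placeholder model name; a `torchrun`-style launch is assumed since the patch initializes a process group):

```
from transformers import AutoModelForCausalLM

apply_dist_flash_attn_monkey_patch_llama()  # also calls initialize_distributed()

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").cuda()
model.gradient_checkpointing_enable()  # required to hit the checkpointed fast path
model.train()
```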

View File

@@ -0,0 +1,34 @@
def extract_local(value, rank, world_size, device, dim=1):
value_local = value.chunk(world_size, dim=dim)[rank]
return value_local.to(device)
def prepare_dist_flash_attn_inputs(
input_ids, position_ids, target_ids, rank, world_size, device
):
local_input_ids = extract_local(
input_ids,
rank,
world_size,
device,
)
local_position_ids = extract_local(
position_ids,
rank,
world_size,
device,
)
if target_ids is not None:
local_target_ids = extract_local(
target_ids,
rank,
world_size,
device,
)
else:
local_target_ids = None
return {
"local_input_ids": local_input_ids,
"local_position_ids": local_position_ids,
"local_target_ids": local_target_ids,
}
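
For intuition, `extract_local` above is just an even chunk along the sequence dimension. A toy example with a length-8 sequence and `world_size=4`:

```
import torch

input_ids = torch.arange(8).unsqueeze(0)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
shards = [extract_local(input_ids, rank, 4, "cpu") for rank in range(4)]
# rank 0 -> [[0, 1]], rank 1 -> [[2, 3]], rank 2 -> [[4, 5]], rank 3 -> [[6, 7]]
```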

View File

@@ -0,0 +1,114 @@
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import transformers
try:
from yunchang.ulysses import UlyssesAttention
ulysses_attn = UlyssesAttention()
except Exception:
print(
"If you want to use the UlyssesAttention class, please install the yunchang package."
)
ulysses_attn = None
def new_flash_attn_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
assert attention_mask is None
assert causal is True
assert use_sliding_windows is False
attn_output = ulysses_attn(
query_states,
key_states,
value_states,
dropout,
softmax_scale,
causal=causal,
)
return attn_output
def new_decoder_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
assert isinstance(
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
) or isinstance(
self.self_attn,
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def apply_ulysses_attn_monkey_patch_llama():
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
new_flash_attn_forward
)
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
new_decoder_forward
)

View File

@@ -0,0 +1,44 @@
import torch
def extract_local(value, rank, world_size, device, dim=1):
dimension_size = value.shape[dim]
sub_seq_length = dimension_size // world_size
sub_seq_start = rank * sub_seq_length
sub_seq_end = (rank + 1) * sub_seq_length
local_value = value[:, sub_seq_start:sub_seq_end]
return local_value.to(device)
def prepare_ulysses_attn_inputs(
input_ids, position_ids, target_ids, rank, world_size, device
):
local_input_ids = extract_local(
input_ids,
rank,
world_size,
device,
)
local_position_ids = extract_local(
position_ids,
rank,
world_size,
device,
)
if target_ids is not None:
local_target_ids = extract_local(
target_ids,
rank,
world_size,
device,
)
else:
local_target_ids = None
return {
"local_input_ids": local_input_ids,
"local_position_ids": local_position_ids,
"local_target_ids": local_target_ids,
}

View File

@@ -0,0 +1,95 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import torch
import transformers
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
"""
Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non-blocking calls.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
pass
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)
pass
pass
def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    assert gradient_checkpointing_kwargs is None
if not self.supports_gradient_checkpointing:
raise ValueError(
f"{self.__class__.__name__} does not support gradient checkpointing."
)
gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = (
"value" in inspect.signature(self._set_gradient_checkpointing).parameters
)
if not _is_using_old_format:
self._set_gradient_checkpointing(
enable=True, gradient_checkpointing_func=gradient_checkpointing_func
)
else:
raise NotImplementedError()
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
new_gradient_checkpointing_enable
)
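
A minimal sketch of enabling the offloaded checkpointer (placeholder model name); the patch must be applied before `gradient_checkpointing_enable` is called:

```
from transformers import AutoModelForCausalLM

apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model.gradient_checkpointing_enable()  # now routes through the CPU-offloading checkpointer
model.train()
```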

View File

@@ -0,0 +1,114 @@
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import transformers
try:
from yunchang import LongContextAttention, set_seq_parallel_pg
usp_attn = LongContextAttention(ring_impl_type="zigzag")
except Exception:
print(
"If you want to use the LongContextAttention class, please install the yunchang package."
)
usp_attn = None
def new_flash_attn_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
assert attention_mask is None
assert causal is True
assert use_sliding_windows is False
attn_output = usp_attn(
query_states,
key_states,
value_states,
dropout,
softmax_scale,
causal=causal,
)
return attn_output
def new_decoder_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
assert isinstance(
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
) or isinstance(
self.self_attn,
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def apply_usp_attn_monkey_patch_llama():
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
new_flash_attn_forward
)
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
new_decoder_forward
)

View File

@@ -0,0 +1,58 @@
import torch
from yunchang import set_seq_parallel_pg
from yunchang.comm import zigzag_extract_local
def prepare_usp_attn_inputs(
input_ids,
position_ids,
target_ids,
rank,
world_size,
device,
    ulysses_degree,
    ring_degree,
):
f"""
prepare input for USP attention
USP: A Unified Sequence Parallelism Approach for Long Context Generative AI
https://arxiv.org/abs/2405.07719
"""
set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)
local_input_ids = zigzag_extract_local(
input_ids,
rank,
world_size,
ring_degree,
ulysses_degree,
).to(device)
    # truncate position_ids to the same length as input_ids before extracting the local shard
    position_ids = position_ids[:, : input_ids.shape[1]]
local_position_ids = zigzag_extract_local(
position_ids,
rank,
world_size,
ring_degree,
ulysses_degree,
).to(device)
if target_ids is not None:
local_target_ids = zigzag_extract_local(
target_ids,
rank,
world_size,
ring_degree,
ulysses_degree,
).to(device)
else:
local_target_ids = None
return {
"local_input_ids": local_input_ids,
"local_position_ids": local_position_ids,
"local_target_ids": local_target_ids,
}
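
A per-rank usage sketch, assuming torch.distributed is already initialized and the full-length batch is replicated on every rank before sharding (the sequence length and degree split are only illustrative):

```
import torch
import torch.distributed as dist

rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

# Full-length batch, identical on every rank before sharding.
input_ids = torch.arange(4096).unsqueeze(0)
position_ids = torch.arange(4096).unsqueeze(0)
target_ids = input_ids.clone()

local = prepare_usp_attn_inputs(
    input_ids, position_ids, target_ids, rank, world_size, device,
    ulysses_degree=world_size, ring_degree=1,  # illustrative: pure Ulysses split
)
# Each rank now holds a 1/world_size zigzag shard of the sequence.
print(local["local_input_ids"].shape)
```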

View File

@@ -0,0 +1,114 @@
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import transformers
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
def new_flash_attn_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
causal = self.is_causal and query_length != 1
    # This path assumes no padding tokens, causal attention, and no sliding window.
    assert attention_mask is None
    assert causal is True
    assert use_sliding_windows is False
attn_output = zigzag_ring_flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale,
causal=causal,
)
return attn_output
def new_decoder_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
assert isinstance(
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
) or isinstance(
self.self_attn,
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def apply_zigzag_ring_attn_monkey_patch_llama():
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
new_flash_attn_forward
)
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
new_decoder_forward
)
def apply_zigzag_ring_attn_monkey_patch_mistral():
transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
new_flash_attn_forward
)
transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
new_decoder_forward
)
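
A hedged end-to-end sketch of how the patch combines with the input preparation defined in the next file; `model`, `input_ids`, `position_ids`, `target_ids`, `rank`, `world_size`, and `device` are placeholders, not part of this integration:

```
# Patch Llama's flash attention, shard the batch per rank, then run the forward pass.
apply_zigzag_ring_attn_monkey_patch_llama()

local = prepare_zigzag_ring_attn_inputs(  # defined in the following file
    input_ids, position_ids, target_ids, rank, world_size, device
)
outputs = model(
    input_ids=local["local_input_ids"],
    position_ids=local["local_position_ids"],
    use_cache=False,
)
# local["local_target_ids"] holds the matching shard of the labels for the loss.
```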

View File

@@ -0,0 +1,40 @@
import torch
def extract_local(value, rank, world_size, device, dim=1):
    # Zigzag split: cut the sequence into 2 * world_size chunks; each rank keeps
    # one chunk from the front and the mirrored chunk from the back so the
    # causal-attention workload is balanced across ranks.
    value_chunks = value.chunk(2 * world_size, dim=dim)
    local_value = torch.cat(
        [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
    )
    return local_value.to(device)
def prepare_zigzag_ring_attn_inputs(
input_ids, position_ids, target_ids, rank, world_size, device
):
local_input_ids = extract_local(
input_ids,
rank,
world_size,
device,
)
local_position_ids = extract_local(
position_ids,
rank,
world_size,
device,
)
if target_ids is not None:
local_target_ids = extract_local(
target_ids,
rank,
world_size,
device,
)
else:
local_target_ids = None
return {
"local_input_ids": local_input_ids,
"local_position_ids": local_position_ids,
"local_target_ids": local_target_ids,
}
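
To make the zigzag layout concrete, a toy illustration of `extract_local` (world size and sequence length are arbitrary):

```
import torch

seq = torch.arange(8).unsqueeze(0)  # [[0, 1, 2, 3, 4, 5, 6, 7]]

# With world_size=2 the sequence is cut into 4 chunks of length 2; each rank
# keeps one chunk from the front and its mirror from the back.
print(extract_local(seq, rank=0, world_size=2, device="cpu"))  # tensor([[0, 1, 6, 7]])
print(extract_local(seq, rank=1, world_size=2, device="cpu"))  # tensor([[2, 3, 4, 5]])
```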

View File

@@ -0,0 +1,123 @@
"""
Utilities for sequence parallelism implementation.
Modified from:
https://github.com/Qihoo360/360-LLaMA-Factory/blob/f295a5760cceebe069fb5b975813d2c945598acb/src/llamafactory/model/model_utils/sequence_parallel.py
"""
from functools import partial
import torch.distributed as dist
import transformers
import transformers.modeling_attn_mask_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
def ring_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = ring_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)
return attn_output
def zigzag_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = zigzag_ring_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)
return attn_output
def stripe_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = stripe_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)
return attn_output
def init_sp_group(sp_size):
assert dist.is_initialized()
world_size = dist.get_world_size()
assert (
world_size % sp_size == 0
), "Total number of GPUs must be a multiple of sequence_parallel_size."
sp_group_num = world_size // sp_size
sp_ranks_list = [
list(range(i * sp_size, i * sp_size + sp_size)) for i in range(sp_group_num)
]
sp_groups = [dist.new_group(sp_ranks_this) for sp_ranks_this in sp_ranks_list]
global_rank_this = dist.get_rank()
sp_idx = global_rank_this // sp_size
return sp_groups[sp_idx]
def apply_sequence_parallel(cfg):
if cfg.sequence_parallel_size == 1:
return None # no sequence parallelism
# init sequence-parallel groups here
group_this = init_sp_group(cfg.sequence_parallel_size)
if cfg.sequence_parallel_mode == "ring":
new_flash_attention_forward = partial(ring_flash_attn_forward, group=group_this)
elif cfg.sequence_parallel_mode == "zigzag-ring":
new_flash_attention_forward = partial(
zigzag_flash_attn_forward, group=group_this
)
elif cfg.sequence_parallel_mode == "stripe":
new_flash_attention_forward = partial(
stripe_flash_attn_forward, group=group_this
)
else:
        raise NotImplementedError(
            f"Unsupported sequence_parallel_mode: {cfg.sequence_parallel_mode!r}. "
            "Supported modes are 'ring', 'zigzag-ring', and 'stripe'."
        )
# monkey patching
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
return group_this
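
A minimal sketch of how `apply_sequence_parallel` can be driven; the `SimpleNamespace` stands in for the real config object, and the process-group setup is normally handled by the training launcher:

```
from types import SimpleNamespace

import torch.distributed as dist

dist.init_process_group("nccl")  # normally handled by the training launcher

cfg = SimpleNamespace(sequence_parallel_size=4, sequence_parallel_mode="zigzag-ring")
sp_group = apply_sequence_parallel(cfg)
# transformers' _flash_attention_forward is now routed through the zigzag ring
# implementation, scoped to the returned process group.
```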

View File

@@ -547,6 +547,20 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_size > 1:
from axolotl.integrations.easy_context import (
apply_seq_parallel_monkey_patch,
)
method = self.cfg.sequence_parallel_mode
model_type = self.cfg.model_type
# Apply the monkey patch
apply_seq_parallel_monkey_patch(method, model_type)
# Ensure flash attention is enabled when loading the model
self.cfg.attn_implementation = "flash_attention_2"
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention: