adding easy_context as integration for now
This commit is contained in:
@@ -36,6 +36,7 @@ einops
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
@@ -64,3 +65,6 @@ schedulefree==1.3.0
|
|||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
|
# for sequence parallelism
|
||||||
|
ring-flash-attn>=0.1.4
|
||||||
|
|||||||
96
src/axolotl/integrations/easy_context/__init__.py
Normal file
96
src/axolotl/integrations/easy_context/__init__.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
|
||||||
|
from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs
|
||||||
|
from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
|
||||||
|
from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs
|
||||||
|
from .usp.monkey_patch import apply_usp_attn_monkey_patch_llama
|
||||||
|
from .usp.prepare_inputs import prepare_usp_attn_inputs
|
||||||
|
from .zigzag_ring_attn.monkey_patch import (
|
||||||
|
apply_zigzag_ring_attn_monkey_patch_llama,
|
||||||
|
apply_zigzag_ring_attn_monkey_patch_mistral,
|
||||||
|
)
|
||||||
|
from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_seq_parallel_inputs(
    seq_algo,
    input_ids,
    position_ids,
    target_ids,
    rank,
    world_size,
    device,
    *args,
    **kwargs,
):
    """Prepare one batch for the selected sequence-parallel algorithm.

    Args:
        seq_algo: One of ``"zigzag_ring_attn"``, ``"dist_flash_attn"``,
            ``"ulysses_attn"``, ``"usp_attn"`` or ``"data_parallel"``.
        input_ids: Batch input ids, forwarded to the algorithm helper.
        position_ids: Batch position ids, forwarded to the algorithm helper.
        target_ids: Batch target ids, forwarded to the algorithm helper.
        rank: Rank of the calling process, forwarded to the helper.
        world_size: Number of participating processes.
        device: Device the prepared tensors should live on.
        *args: Accepted for call-site compatibility; not used.
        **kwargs: Only ``ring_degree`` (default ``1``) is read, and only for
            ``"usp_attn"``.

    Returns:
        The value returned by the algorithm-specific helper. For
        ``"data_parallel"`` this is a dict with the keys
        ``local_input_ids``, ``local_position_ids`` and ``local_target_ids``
        holding the inputs moved to ``device``.

    Raises:
        ValueError: If ``seq_algo`` is unknown, or if ``ring_degree`` is not
            a positive divisor of ``world_size`` for ``"usp_attn"``.
    """
    if seq_algo == "zigzag_ring_attn":
        return prepare_zigzag_ring_attn_inputs(
            input_ids, position_ids, target_ids, rank, world_size, device
        )
    if seq_algo == "dist_flash_attn":
        return prepare_dist_flash_attn_inputs(
            input_ids, position_ids, target_ids, rank, world_size, device
        )
    if seq_algo == "ulysses_attn":
        return prepare_ulysses_attn_inputs(
            input_ids, position_ids, target_ids, rank, world_size, device
        )
    if seq_algo == "usp_attn":
        ring_degree = kwargs.get("ring_degree", 1)
        # The Ulysses degree is derived by floor division, so a ring degree
        # that does not divide world_size would silently yield degrees whose
        # product is not world_size. Reject it up front.
        if ring_degree <= 0 or world_size % ring_degree != 0:
            raise ValueError(
                f"ring_degree must be a positive divisor of world_size "
                f"({world_size}), got {ring_degree}"
            )
        ulysses_degree = world_size // ring_degree
        logger.info(
            f"Applying USP: Ring degree: {ring_degree}, Ulysses degree: {ulysses_degree}"
        )
        return prepare_usp_attn_inputs(
            input_ids,
            position_ids,
            target_ids,
            rank,
            world_size,
            device,
            ulysses_degree,
            ring_degree,
        )
    if seq_algo == "data_parallel":
        # No sharding: every rank keeps the full batch, just moved to device.
        return {
            "local_input_ids": input_ids.to(device),
            "local_position_ids": position_ids.to(device),
            "local_target_ids": target_ids.to(device),
        }
    raise ValueError(f"Invalid seq_algo: {seq_algo}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_seq_parallel_monkey_patch(seq_algo, model):
    """Install the attention monkey patch for a sequence-parallel algorithm.

    Args:
        seq_algo: One of ``"zigzag_ring_attn"``, ``"dist_flash_attn"``,
            ``"ulysses_attn"``, ``"data_parallel"`` or ``"usp_attn"``.
        model: ``"llama"`` or ``"mistral"``.

    Returns:
        None. ``"data_parallel"`` applies no patch at all.

    Raises:
        AssertionError: If ``seq_algo`` or ``model`` is not a known value.
        ValueError: If the (algorithm, model) pair has no patch; in this
            dispatcher only zigzag ring attention is wired up for mistral.
    """
    known_algos = (
        "zigzag_ring_attn",
        "dist_flash_attn",
        "ulysses_attn",
        "data_parallel",
        "usp_attn",
    )
    assert seq_algo in known_algos, f"Invalid seq_algo: {seq_algo}"
    assert model in ("llama", "mistral"), f"Invalid model: {model}"

    # Plain data parallelism leaves the attention implementation untouched.
    if seq_algo == "data_parallel":
        return

    choice = (seq_algo, model)
    if choice == ("zigzag_ring_attn", "llama"):
        apply_zigzag_ring_attn_monkey_patch_llama()
    elif choice == ("zigzag_ring_attn", "mistral"):
        apply_zigzag_ring_attn_monkey_patch_mistral()
    elif choice == ("dist_flash_attn", "llama"):
        apply_dist_flash_attn_monkey_patch_llama()
    elif choice == ("ulysses_attn", "llama"):
        apply_ulysses_attn_monkey_patch_llama()
    elif choice == ("usp_attn", "llama"):
        apply_usp_attn_monkey_patch_llama()
    else:
        raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataloader(seq_algo, dataloader, accelerator):
    """Return the dataloader to iterate for the given algorithm.

    Only ``"data_parallel"`` hands the dataloader to
    ``accelerator.prepare``; every other algorithm gets the dataloader back
    unchanged.
    """
    if seq_algo != "data_parallel":
        return dataloader
    return accelerator.prepare(dataloader)
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
# LightSeq
|
||||||
|
Taken from https://github.com/RulinShao/LightSeq. All credits to the authors.
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{li2023lightseq,
|
||||||
|
title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANSFORMERS},
|
||||||
|
author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P and Gonzalez, Joseph E and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
|
||||||
|
journal={arXiv preprint arXiv:2310.03294},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
```
|
||||||
@@ -0,0 +1,776 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import P2POp, batch_isend_irecv, irecv, isend
|
||||||
|
|
||||||
|
# Sequence parallel group that the current rank belongs to.
# Assigned by _initialize_sequence_parallel().
_SEQUENCE_PARALLEL_GROUP = None

# These values enable us to change the sequence parallel sizes on the fly.
# Assigned together with the group by _initialize_sequence_parallel(); the
# rank/size getters return them directly whenever they are not None.
_SEQUENCE_PARALLEL_SIZE = None
_SEQUENCE_PARALLEL_RANK = None

# Global buffer for P2P
# Forward buffers: maybe_get_set_global_memory_buffer() replaces each None
# with a two-element list of torch.empty_like() tensors on its first call.
_PEER_Q = None
_PEER_K = None
_PEER_V = None
_PEER_M = None
_PEER_L = None
_PEER_O = None
# Backward peer buffers: allocated by maybe_get_set_global_memory_buffer_bwd().
_PEER_Q_BWD = None
_PEER_K_BWD = None
_PEER_V_BWD = None
_PEER_O_BWD = None

# Backward delta buffers, also allocated by
# maybe_get_set_global_memory_buffer_bwd().
_DELTA_DQ = None
# NOTE(review): _PEER_L is assigned a second time here (first assignment is
# a few lines above); the backward allocator also overwrites it.
_PEER_L = None
_DELTA_DK = None
_DELTA_DV = None
_DK_DELTA_FROM_PEER = None
_DV_DELTA_FROM_PEER = None
_PEER_DO = None


# Running byte counters; the send/recv helpers only add to them when called
# with debug=True.
_fwd_send_volume = 0
_fwd_recv_volume = 0
_bwd_send_volume = 0
_bwd_recv_volume = 0
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_distributed():
    """Make sure torch.distributed is up, then build the sequence-parallel group.

    If a process group already exists it is reused. Otherwise an NCCL
    process group is created and the CUDA device is selected as
    ``global_rank % LOCAL_WORLD_SIZE``. In both cases
    ``_initialize_sequence_parallel()`` is called afterwards.

    Raises:
        KeyError: If ``RANK`` or ``LOCAL_WORLD_SIZE`` is missing from the
            environment when a process group has to be created.
    """
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
    else:
        # RANK is read from the environment because dist.get_rank() is not
        # usable before init_process_group().
        if int(os.environ["RANK"]) == 0:
            print("Initializing Torch distributed.")
        dist.init_process_group(backend="nccl")
        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(dist.get_rank() % local_world_size)

    _initialize_sequence_parallel()
|
||||||
|
|
||||||
|
|
||||||
|
# create_nccl_communicators()
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_sequence_parallel(sequence_parallel_size=None):
    """Create the sequence-parallel process group(s) and cache this rank's view.

    Only a single group spanning the whole world is supported, so
    ``sequence_parallel_size`` must be left as ``None``. On return the
    module-level group, rank-within-group and group size are set.

    Raises:
        AssertionError: If ``sequence_parallel_size`` is given, if
            torch.distributed is not initialised, or if the group has
            already been created.
    """
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_RANK
    global _SEQUENCE_PARALLEL_SIZE

    # Get world size and rank. Ensure some consistencies.
    assert (
        sequence_parallel_size is None
    ), "Multiple sequence parallel group not implemented."
    assert torch.distributed.is_initialized()
    total_ranks: int = torch.distributed.get_world_size()

    if sequence_parallel_size is not None:
        assert total_ranks % sequence_parallel_size == 0
    else:
        sequence_parallel_size = total_ranks
    group_count: int = total_ranks // sequence_parallel_size

    my_rank = torch.distributed.get_rank()

    # Build the sequence parallel groups.
    assert (
        _SEQUENCE_PARALLEL_GROUP is None
    ), "sequence parallel group is already initialized"
    for group_idx in range(group_count):
        first = group_idx * sequence_parallel_size
        members = range(first, first + sequence_parallel_size)
        # new_group is called for every group on every rank, not only for
        # the group this rank ends up belonging to.
        candidate = torch.distributed.new_group(members)
        if my_rank in members:
            _SEQUENCE_PARALLEL_GROUP = candidate
            _SEQUENCE_PARALLEL_RANK = members.index(my_rank)
            _SEQUENCE_PARALLEL_SIZE = len(members)

    if dist.get_rank() == 0:
        print("************ Finish sequence pralell group Initialization. ***********")
|
||||||
|
# _set_global_memory_buffer()
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_get_set_global_memory_buffer(q, k, v, m, l, o):
    """Return the forward P2P receive buffers, allocating them on first use.

    Allocation happens only while ``_PEER_Q`` is ``None``: each buffer
    becomes a two-element list of ``torch.empty_like`` tensors shaped after
    the matching argument. Later calls return the cached lists without
    looking at the arguments again.

    Returns:
        Tuple ``(_PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O)``.
    """
    global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
    if _PEER_Q is None:
        # Log from sequence rank 0 only; if the rank cannot be determined
        # (e.g. the group is not initialised) log unconditionally. Catch
        # Exception rather than everything so KeyboardInterrupt/SystemExit
        # still propagate.
        try:
            should_announce = get_sequence_parallel_rank() == 0
        except Exception:
            should_announce = True
        if should_announce:
            print("Initializing global memoery buffer.")
        _PEER_Q = [torch.empty_like(q) for _ in range(2)]
        _PEER_K = [torch.empty_like(k) for _ in range(2)]
        _PEER_V = [torch.empty_like(v) for _ in range(2)]
        _PEER_M = [torch.empty_like(m) for _ in range(2)]
        _PEER_L = [torch.empty_like(l) for _ in range(2)]
        _PEER_O = [torch.empty_like(o) for _ in range(2)]

    return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do):
    """Return the backward P2P buffers, allocating them on first use.

    Allocation happens only while ``_DELTA_DQ`` is ``None``. The delta and
    peer buffers become two-element lists of ``torch.empty_like`` tensors;
    ``_DK_DELTA_FROM_PEER`` and ``_DV_DELTA_FROM_PEER`` are single tensors.
    ``_PEER_L`` is shared with the forward buffers and is replaced here.

    Returns:
        Tuple ``(_DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER,
        _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD,
        _PEER_O_BWD, _PEER_DO)``.
    """
    global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO
    if _DELTA_DQ is None:
        # Log from sequence rank 0 only; if the rank cannot be determined,
        # log unconditionally. Catch Exception rather than everything so
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            should_announce = get_sequence_parallel_rank() == 0
        except Exception:
            should_announce = True
        if should_announce:
            print("Initializing global memoery buffer for backward.")
        _DELTA_DQ = [torch.empty_like(dq) for _ in range(2)]
        _DELTA_DK = [torch.empty_like(dk) for _ in range(2)]
        _DELTA_DV = [torch.empty_like(dv) for _ in range(2)]
        _PEER_L = [torch.empty_like(L) for _ in range(2)]

        _DK_DELTA_FROM_PEER = torch.empty_like(dk)
        _DV_DELTA_FROM_PEER = torch.empty_like(dv)

        # may already be initialized in the forward call.
        # current forward and backward needs a transpose in q's format
        _PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)]
        _PEER_K_BWD = [torch.empty_like(k) for _ in range(2)]
        _PEER_V_BWD = [torch.empty_like(v) for _ in range(2)]
        _PEER_O_BWD = [torch.empty_like(o) for _ in range(2)]

        _PEER_DO = [torch.empty_like(do) for _ in range(2)]

    return (
        _DELTA_DQ,
        _DELTA_DK,
        _DELTA_DV,
        _DK_DELTA_FROM_PEER,
        _DV_DELTA_FROM_PEER,
        _PEER_Q_BWD,
        _PEER_L,
        _PEER_K_BWD,
        _PEER_V_BWD,
        _PEER_O_BWD,
        _PEER_DO,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def reset_global_memory_buffer():
    """Drop every cached P2P buffer so the next forward/backward reallocates.

    Clears the forward peer buffers, the backward delta buffers and the
    backward peer buffers. The ``*_BWD`` buffers are cleared too; they are
    re-created together with ``_DELTA_DQ`` on the next backward allocation,
    so keeping them would only hold on to stale tensors.
    """
    global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
    global _PEER_Q_BWD, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD
    global _DELTA_DQ, _DELTA_DK, _DELTA_DV
    global _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO

    # Forward buffers.
    _PEER_Q = None
    _PEER_K = None
    _PEER_V = None
    _PEER_M = None
    _PEER_L = None
    _PEER_O = None

    # Backward peer buffers.
    _PEER_Q_BWD = None
    _PEER_K_BWD = None
    _PEER_V_BWD = None
    _PEER_O_BWD = None

    # Backward delta buffers.
    _DELTA_DQ = None
    _DELTA_DK = None
    _DELTA_DV = None
    _DK_DELTA_FROM_PEER = None
    _DV_DELTA_FROM_PEER = None
    _PEER_DO = None
|
||||||
|
|
||||||
|
|
||||||
|
# Pytorch defers the creation of nccl communicators to the first P2P call,
|
||||||
|
# We manually create them so the first isend does not hang without an irecv.
|
||||||
|
# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138
|
||||||
|
# Only support even number of GPUs.
|
||||||
|
def create_nccl_communicators():
    """Warm up P2P communicators by exchanging dummy tensors with a neighbour.

    Even global ranks pair with sequence rank ``seq_rank + 1`` and post
    (isend, irecv); odd global ranks pair with ``seq_rank - 1`` and post
    (irecv, isend). Every rank then joins an all-reduce on the sequence
    group. The requests returned by ``batch_isend_irecv`` are not waited on
    here.

    NOTE(review): parity is taken from the global rank while the peer is a
    sequence-group rank; the two agree only while the sequence group spans
    the whole world — confirm if multiple groups are ever enabled.
    """
    seq_rank = get_sequence_parallel_rank()
    seq_group = get_sequence_parallel_group()

    empty_tensor = torch.empty(
        1,
    ).cuda()

    if torch.distributed.get_rank() % 2 == 0:
        # sender
        peer = seq_rank + 1
        op_order = (isend, irecv)
    else:
        # receiver
        peer = seq_rank - 1
        op_order = (irecv, isend)

    # One fresh 1-element CUDA tensor per op, posted in the same order as
    # the paired rank expects.
    ops = [
        P2POp(
            op=p2p_op,
            tensor=torch.empty(
                1,
            ).cuda(),
            peer=peer,
            group=seq_group,
        )
        for p2p_op in op_order
    ]
    dist.batch_isend_irecv(ops)

    dist.all_reduce(empty_tensor, group=seq_group)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sequence_parallel_group():
    """Return the sequence-parallel process group of the calling rank.

    Raises:
        AssertionError: If the group has not been initialised yet.
    """
    group = _SEQUENCE_PARALLEL_GROUP
    assert group is not None, "sequence parallel group is not initialized"
    return group
|
||||||
|
|
||||||
|
|
||||||
|
def get_sequence_parallel_rank():
    """Return this process's rank inside the sequence-parallel group.

    The cached module-level value is returned when it is set; otherwise the
    rank is queried from torch.distributed for the current group.
    """
    cached_rank = _SEQUENCE_PARALLEL_RANK
    if cached_rank is None:
        return torch.distributed.get_rank(group=get_sequence_parallel_group())
    return cached_rank
|
||||||
|
|
||||||
|
|
||||||
|
def get_sequence_parallel_size():
    """Return the number of ranks in the sequence-parallel group.

    The cached module-level value is returned when it is set; otherwise the
    size is queried from torch.distributed for the current group.
    """
    cached_size = _SEQUENCE_PARALLEL_SIZE
    if cached_size is None:
        return torch.distributed.get_world_size(group=get_sequence_parallel_group())
    return cached_size
|
||||||
|
|
||||||
|
|
||||||
|
def destroy_sequence_parallel():
    """Forget the sequence-parallel group and its cached rank and size.

    The cached rank and size are cleared together with the group;
    otherwise ``get_sequence_parallel_rank()`` and
    ``get_sequence_parallel_size()`` would keep returning values from the
    discarded group. Only the module-level references are cleared here.
    """
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_RANK
    global _SEQUENCE_PARALLEL_SIZE
    _SEQUENCE_PARALLEL_GROUP = None
    _SEQUENCE_PARALLEL_RANK = None
    _SEQUENCE_PARALLEL_SIZE = None
|
||||||
|
|
||||||
|
|
||||||
|
# whether this is the last time the kernel being called
|
||||||
|
def is_last_time(time_step):
    """Return True when ``time_step`` is this rank's final kernel invocation.

    The final step is ``min(seq_rank, seq_world_size // 2)``. E.g. with 8
    ranks: ranks 0..3 finish at step 0..3 respectively, ranks 4..7 all
    finish at step 4.
    """
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()
    # Ranks up to the midpoint get no help and finish at their own index;
    # later ranks are capped at the midpoint.
    return time_step == min(seq_rank, seq_world_size // 2)
|
||||||
|
|
||||||
|
|
||||||
|
# Whether the current time step is computing for local q
|
||||||
|
def is_compute_for_local_query(time_step):
    """Return True when this step computes attention for the local query.

    That is the case exactly when
    ``seq_rank >= min(seq_world_size // 2, time_step)``.
    """
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()
    threshold = min(seq_world_size // 2, time_step)
    return seq_rank >= threshold
|
||||||
|
|
||||||
|
|
||||||
|
# Whether the current time step is idle
|
||||||
|
def is_idle(time_step):
    """Return True when this rank has nothing to compute at ``time_step``.

    Only ranks in the first half of the group are ever idle, and only at
    step ``seq_world_size // 2`` (e.g. ranks 0..3 at step 4 with 8 ranks).
    """
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()

    midpoint = seq_world_size // 2
    return seq_rank < midpoint and time_step == midpoint
|
||||||
|
|
||||||
|
|
||||||
|
# Whether the current time step needs to synchronize with a remote computed result
|
||||||
|
def is_sync_from_remote(time_step):
    """Return True when this step must merge a remotely computed result.

    True exactly when
    ``seq_rank > max(seq_world_size // 2, seq_world_size - time_step)``.
    E.g. with 8 ranks: rank 5 at step 4, rank 6 at steps 3-4, rank 7 at
    steps 2-4; ranks 0..4 never.
    """
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()
    lower_bound = max(seq_world_size // 2, seq_world_size - time_step)
    return seq_rank > lower_bound
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_send_recv_fwd_qkvo(
    q: torch.Tensor,
    peer_q: torch.Tensor,
    k: torch.Tensor,
    peer_k: torch.Tensor,
    v: torch.Tensor,
    peer_v: torch.Tensor,
    o_stats: list,  # peer_o_stats: list,
    time_step: int,
    comm_mode,
    debug=False,
) -> list:
    """Queue and launch the forward-pass P2P transfers for one time step.

    Depending on this rank's position and ``time_step`` the function posts
    sends of ``q`` or of ``k``/``v``, matching receives into ``peer_q`` or
    ``peer_k``/``peer_v``, and, from ``time_step > 1`` on, sends or
    receives of every tensor in ``o_stats``.

    Args:
        q, k, v: Local tensors that may be sent.
        peer_q, peer_k, peer_v: Buffers that may be received into.
        o_stats: Tensors exchanged for load-balanced output merging; the
            same tensors are used as send source or receive target.
        time_step: Current step of the schedule.
        comm_mode: Passed through to ``launch_async_handles``.
        debug: When true, accumulates byte counts into the module-level
            forward volume counters and prints the per-step increase on
            sequence ranks 0 and 8.

    Returns:
        A one-element list holding the value returned by
        ``launch_async_handles``.
    """
    global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume

    seq_group = get_sequence_parallel_group()
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()

    # Handles for operations that actually need to be wait before going to the next iteration.
    # For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler;
    all_handles = []

    def enqueue(op, tensor, peer):
        """Append one P2P op on the sequence group to the pending batch."""
        all_handles.append(P2POp(op=op, tensor=tensor, peer=peer, group=seq_group))

    def nbytes(tensor):
        """Size of ``tensor`` in bytes."""
        return torch.numel(tensor) * tensor.element_size()

    # KV logic: different than older version, every rank to send/recv its own kv,
    # to balance communication. In a balanced communication, every step each rank
    # should send/recv 4 tensors in total (kv, or qo). For instance, rank 0 when
    # time step > 0, should send its own kv and send/recv qo. In the older version,
    # rank 0 does not send its kv, and rely on a later rank to pass it, where the
    # later rank has to (1) receive kv, send rank 0's kv and send/recv qo.
    # Q (load balancing) logic: semantically, this will be "%" world size, so
    # the same send/recv rank as KV. Note: Only support even number of machines.
    # O (load balancing) logic: rank 0 sends result to rank 7 at time 1.
    # It get delayed for one time step, and thus has different maybe_send/recv_rank.
    # Use (time_step + 1) to easily convert to synchronize version.
    maybe_send_rank = seq_rank + (time_step + 1)
    maybe_recv_rank = seq_rank - (time_step + 1)

    if debug:
        _debug_send = _fwd_send_volume
        _debug_recv = _fwd_recv_volume

    if maybe_send_rank >= seq_world_size:
        # send q, no one needs to do remote computation in the last time step
        if time_step < (seq_world_size // 2 - 1):
            enqueue(isend, q, maybe_send_rank % seq_world_size)
            if debug:
                _fwd_send_volume += nbytes(q)
    else:
        # send kv
        enqueue(isend, k, maybe_send_rank)
        enqueue(isend, v, maybe_send_rank)
        if debug:
            _fwd_send_volume += nbytes(k)
            _fwd_send_volume += nbytes(v)

    if maybe_recv_rank < 0:
        # recv q, no one needs to do remote computation in the last time step
        if time_step < (seq_world_size // 2 - 1):
            enqueue(irecv, peer_q, maybe_recv_rank % seq_world_size)
            if debug:
                _fwd_recv_volume += nbytes(peer_q)
    else:
        # recv kv
        enqueue(irecv, peer_k, maybe_recv_rank)
        enqueue(irecv, peer_v, maybe_recv_rank)
        if debug:
            _fwd_recv_volume += nbytes(peer_k)
            _fwd_recv_volume += nbytes(peer_v)

    # Output exchange lags the q/kv exchange, hence (time_step - 1).
    maybe_send_rank_o = seq_rank - (time_step - 1)
    maybe_recv_rank_o = seq_rank + (time_step - 1)
    if maybe_send_rank_o < 0 and time_step > 1:
        for t in o_stats:
            enqueue(isend, t, maybe_send_rank_o % seq_world_size)
            if debug:
                _fwd_send_volume += nbytes(t)
    if maybe_recv_rank_o >= seq_world_size and time_step > 1:
        for t in o_stats:
            enqueue(irecv, t, maybe_recv_rank_o % seq_world_size)
            if debug:
                _fwd_recv_volume += nbytes(t)

    if debug:
        if seq_rank in [0, 8]:
            print(
                f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB"
            )

    all_reqs = launch_async_handles(all_handles, comm_mode)
    return [all_reqs]
|
||||||
|
|
||||||
|
|
||||||
|
# delta: may be you are using it for your local compute or as a distributed buffer to send to others
|
||||||
|
# .. Sorry for the bad naming..
|
||||||
|
def maybe_send_recv_bwd_qkvo(
    dq_delta: torch.Tensor,
    dk_delta: torch.Tensor,
    dv_delta: torch.Tensor,
    dk_delta_from_peer: torch.Tensor,
    dv_delta_from_peer: torch.Tensor,
    q: torch.Tensor,
    peer_q: torch.Tensor,
    L: torch.Tensor,
    peer_L: torch.Tensor,
    k: torch.Tensor,
    peer_k: torch.Tensor,
    v: torch.Tensor,
    peer_v: torch.Tensor,
    o: torch.Tensor,
    peer_o: torch.Tensor,
    do: torch.Tensor,
    peer_do: torch.Tensor,
    time_step: int,
    comm_mode,
    debug=False,
):
    """Queue and launch the backward-pass P2P transfers for one time step.

    Posts sends of either ``q``/``L``/``o``/``do`` or ``k``/``v`` with the
    matching receives into the ``peer_*`` buffers, and, from
    ``time_step > 1`` on, the exchange of gradient deltas.

    Args:
        dq_delta: Sent as the dq delta, or used as the receive target.
        dk_delta, dv_delta: dk/dv deltas that may be sent.
        dk_delta_from_peer, dv_delta_from_peer: Receive targets for
            incoming dk/dv deltas.
        q, L, k, v, o, do: Local tensors that may be sent.
        peer_q, peer_L, peer_k, peer_v, peer_o, peer_do: Receive targets.
        time_step: Current step of the schedule.
        comm_mode: Passed through to ``launch_async_handles``.
        debug: When true, accumulates byte counts into the module-level
            backward volume counters.

    Returns:
        ``([reqs], is_update_dq, is_update_dkv)`` where ``reqs`` is the
        value returned by ``launch_async_handles`` and the flags tell the
        caller whether a dq delta or a dk/dv delta receive was posted.
    """
    global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume

    seq_group = get_sequence_parallel_group()
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()

    all_handles = []

    def post(op, peer, *tensors):
        """Append one P2P op per tensor, in order, to the pending batch."""
        for tensor in tensors:
            all_handles.append(
                P2POp(op=op, tensor=tensor, peer=peer, group=seq_group)
            )

    def total_bytes(*tensors):
        """Combined size of ``tensors`` in bytes."""
        return sum(torch.numel(t) * t.element_size() for t in tensors)

    send_to = seq_rank + (time_step + 1)
    recv_from = seq_rank - (time_step + 1)
    # No one needs to do remote computation in the last time step.
    has_remote_query = time_step < (seq_world_size // 2 - 1)

    if send_to < seq_world_size:
        # send kv
        post(isend, send_to, k, v)
        if debug:
            _bwd_send_volume += total_bytes(k, v)
    elif has_remote_query:
        # send q (together with L, o and do) to the wrapped-around rank
        post(isend, send_to % seq_world_size, q, L, o, do)
        if debug:
            _bwd_send_volume += total_bytes(q, L, o, do)

    if recv_from >= 0:
        # recv kv
        post(irecv, recv_from, peer_k, peer_v)
        if debug:
            _bwd_recv_volume += total_bytes(peer_k, peer_v)
    elif has_remote_query:
        # recv q (together with L, o and do) from the wrapped-around rank
        post(irecv, recv_from % seq_world_size, peer_q, peer_L, peer_o, peer_do)
        if debug:
            _bwd_recv_volume += total_bytes(peer_q, peer_L, peer_o, peer_do)

    # Whether I should update dq, dk and dv after waiting these requests
    is_update_dq = False
    is_update_dkv = False

    delta_send_to = seq_rank - (time_step - 1)
    delta_recv_from = seq_rank + (time_step - 1)

    if time_step > 1:
        if delta_send_to >= 0:
            post(isend, delta_send_to, dk_delta, dv_delta)
            if debug:
                _bwd_send_volume += total_bytes(dk_delta, dv_delta)
        else:
            post(isend, delta_send_to % seq_world_size, dq_delta)
            if debug:
                _bwd_send_volume += total_bytes(dq_delta)

        if delta_recv_from < seq_world_size:
            post(irecv, delta_recv_from, dk_delta_from_peer, dv_delta_from_peer)
            is_update_dkv = True
            if debug:
                _bwd_recv_volume += total_bytes(
                    dk_delta_from_peer, dv_delta_from_peer
                )
        else:
            post(irecv, delta_recv_from % seq_world_size, dq_delta)
            is_update_dq = True
            if debug:
                _bwd_recv_volume += total_bytes(dq_delta)

    all_reqs = launch_async_handles(all_handles, comm_mode)
    return [all_reqs], is_update_dq, is_update_dkv
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_send_recv_bwd_last_dkv(
    dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False
):
    """Queue the final backward-pass exchange of the dk/dv deltas.

    Communication is only queued when ``time_step == seq_world_size // 2``.
    At that step a rank sends ``dk_delta``/``dv_delta`` to rank
    ``seq_rank - time_step`` when that index is non-negative, and receives
    into the same two tensors from rank ``seq_rank + time_step`` when that
    index is below the sequence-parallel world size.

    Args:
        dk_delta: Tensor sent to the peer, or used as the receive buffer.
        dv_delta: Same role as ``dk_delta`` for the value gradient.
        time_step: Current step of the backward schedule.
        comm_mode: Forwarded unchanged to ``launch_async_handles``.
        debug: When true, add the transferred byte counts to the module-level
            backward send/receive volume counters.

    Returns:
        ``(reqs, is_update_last_dkv)``. ``reqs`` is ``[]`` for a single-rank
        group, otherwise a one-element list wrapping the value returned by
        ``launch_async_handles``. ``is_update_last_dkv`` is True only when a
        receive was queued.
    """
    # Declared unconditionally: a ``global`` statement applies to the whole
    # function no matter which branch it is written in.
    global _bwd_send_volume, _bwd_recv_volume

    is_update_last_dkv = False

    seq_group = get_sequence_parallel_group()
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()

    if seq_world_size == 1:
        return [], is_update_last_dkv

    all_handles = []

    if time_step == seq_world_size // 2:
        maybe_send_rank = seq_rank - time_step
        maybe_recv_rank = seq_rank + time_step

        # Exactly one of the two directions must apply to this rank.
        assert (maybe_send_rank >= 0) ^ (
            maybe_recv_rank < seq_world_size
        ), f"R={seq_rank} should be either sending or receiving dkv in the last time step."

        if maybe_send_rank >= 0:
            for delta in (dk_delta, dv_delta):
                all_handles.append(
                    P2POp(op=isend, tensor=delta, peer=maybe_send_rank, group=seq_group)
                )
            if debug:
                _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size()
                _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size()

        if maybe_recv_rank < seq_world_size:
            for delta in (dk_delta, dv_delta):
                all_handles.append(
                    P2POp(op=irecv, tensor=delta, peer=maybe_recv_rank, group=seq_group)
                )
            if debug:
                _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size()
                _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size()
            is_update_last_dkv = True

    all_reqs = launch_async_handles(all_handles, comm_mode)

    return [all_reqs], is_update_last_dkv
|
||||||
|
|
||||||
|
|
||||||
|
def print_and_reset_comm_stats():
    """Print this rank's accumulated communication volumes and zero the counters.

    The four module-level byte counters are scaled by 1e-9 before printing
    (the message labels the result as GB) and are then reset to 0.
    """
    global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume

    seq_rank = get_sequence_parallel_rank()

    bytes_to_gb = 1e-9
    _fwd_send_volume = _fwd_send_volume * bytes_to_gb
    _fwd_recv_volume = _fwd_recv_volume * bytes_to_gb
    _bwd_send_volume = _bwd_send_volume * bytes_to_gb
    _bwd_recv_volume = _bwd_recv_volume * bytes_to_gb

    print(
        f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB."
    )

    # Start the next measurement window from scratch.
    _fwd_send_volume, _fwd_recv_volume = 0, 0
    _bwd_send_volume, _bwd_recv_volume = 0, 0
|
||||||
|
|
||||||
|
|
||||||
|
def launch_async_handles(handles, comm_mode):
    """Launch the queued point-to-point ops as one batch.

    Args:
        handles: Ops to pass to ``dist.batch_isend_irecv``.
        comm_mode: ``"nocomm"`` skips communication entirely (ablation mode).

    Returns:
        The value returned by ``dist.batch_isend_irecv``, or ``[]`` when
        communication is skipped or no ops were queued.
    """
    # The mode check comes first so ``handles`` is never inspected in nocomm mode.
    if comm_mode == "nocomm" or len(handles) == 0:
        return []
    return dist.batch_isend_irecv(handles)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_async_handles(reqs):
    """Call ``wait()`` on every request in a list of request groups.

    Args:
        reqs: A sized collection whose items are iterables of objects
            exposing ``wait()``.
    """
    if len(reqs) == 0:
        return
    for request_group in reqs:
        for request in request_group:
            request.wait()
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,755 @@
|
|||||||
|
"""
|
||||||
|
Materialization-aware gradient checkpointing monkey patch.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.checkpoint import (
|
||||||
|
_get_autocast_kwargs,
|
||||||
|
check_backward_validity,
|
||||||
|
detach_variable,
|
||||||
|
get_device_states,
|
||||||
|
set_device_states,
|
||||||
|
)
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .async_communication import initialize_distributed, reset_global_memory_buffer
|
||||||
|
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
|
||||||
|
|
||||||
|
# define a global buffer to save flash attention outputs
|
||||||
|
# it's called global because it saves the outputs for all layers
|
||||||
|
global_flash_attn_out_buffer = None
|
||||||
|
|
||||||
|
# define a local buffer to save the gradient of the residual
|
||||||
|
# it's called local because it's a temporary buffer which will be updated across layers
|
||||||
|
local_res_grad_buffer = None
|
||||||
|
|
||||||
|
# hooks for the gradients of residual
|
||||||
|
global_hooks = []
|
||||||
|
|
||||||
|
|
||||||
|
def init_flash_attn_buffers(num_layers):
    """Reset the global flash-attention output buffer to ``num_layers`` empty slots."""
    global global_flash_attn_out_buffer
    global_flash_attn_out_buffer = [None for _ in range(num_layers)]
|
||||||
|
|
||||||
|
|
||||||
|
def clean_hook():
    """Remove every hook handle recorded in ``global_hooks`` and empty the list."""
    for handle in global_hooks:
        handle.remove()
    # Empty the list in place so other references to it see the change.
    del global_hooks[:]
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_buffers_at_the_end_of_training():
    """Drop all module-level checkpointing state; call it at the end of training.

    Sets the global flash-attention output buffer and the local residual
    gradient buffer to ``None`` and removes every registered hook.
    """
    # Both names must be declared global: without the declaration the
    # assignment only creates a function-local variable and the module-level
    # buffer is never released.
    global global_flash_attn_out_buffer, local_res_grad_buffer
    global_flash_attn_out_buffer = None
    local_res_grad_buffer = None
    clean_hook()
|
||||||
|
|
||||||
|
|
||||||
|
def save_flash_attn_out_to_global_buffer(idx, out):
    """Store ``out`` in slot ``idx`` of the global flash-attention output buffer."""
    # Item assignment mutates the existing list, so no ``global`` statement is needed.
    global_flash_attn_out_buffer[idx] = out
|
||||||
|
|
||||||
|
|
||||||
|
def get_flash_attn_out_from_global_buffer(idx):
    """Return the entry stored in slot ``idx`` of the global flash-attention output buffer."""
    stored_out = global_flash_attn_out_buffer[idx]
    return stored_out
|
||||||
|
|
||||||
|
|
||||||
|
def free_flash_attn_out_buffer(idx):
    """Clear slot ``idx`` of the global flash-attention output buffer by setting it to ``None``."""
    # Item assignment mutates the existing list, so no ``global`` statement is needed.
    global_flash_attn_out_buffer[idx] = None
|
||||||
|
|
||||||
|
|
||||||
|
def write_gradient_to_flash_attn_out(idx, grad):
    """Assign ``grad`` to the ``.grad`` attribute of the buffered output in slot ``idx``."""
    buffered_out = global_flash_attn_out_buffer[idx]
    buffered_out.grad = grad
|
||||||
|
|
||||||
|
|
||||||
|
def save_res_grad_hook(grad):
    """Tensor hook that stashes ``grad`` in the module-level residual gradient buffer.

    Returns ``None``, so the gradient seen by autograd is left unchanged.
    """
    global local_res_grad_buffer
    local_res_grad_buffer = grad
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_add_res_grad_hook(grad):
    """Tensor hook that adds the buffered residual gradient to ``grad``.

    Args:
        grad: Gradient handed to the hook by autograd.

    Returns:
        A new tensor equal to ``grad`` plus the gradient currently held in
        the local residual buffer; autograd uses it in place of ``grad``.
    """
    # Return the sum instead of mutating ``grad`` in place: the
    # ``register_hook`` contract says a hook should not modify its argument
    # but may return a replacement gradient.
    return grad + get_res_grad_from_local_buffer()
|
||||||
|
|
||||||
|
|
||||||
|
def get_res_grad_from_local_buffer():
    """Return the residual gradient held in the local buffer.

    Raises:
        AssertionError: If nothing has been stored in the buffer.
    """
    buffered_grad = local_res_grad_buffer
    assert buffered_grad is not None
    return buffered_grad
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function):
    """Avoid doing twice flash attention forward during checkpointed backward.

    ``run_function`` must return ``(q, k, v, residual)``; attention itself is
    run here through ``_lightseq_forward`` and its output is kept in the
    global buffer, so backward only recomputes the part before attention.

    args:
        hidden_states,  # i.e., flash attention output which is saved in global buffer.
        attention_mask,
        position_ids,
        residual,  # the gradient of residual is saved in local buffer to pass across ckpt layers.
    """

    @staticmethod
    def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
        """Run ``run_function`` and attention without grad, saving what backward needs.

        Returns:
            ``(out, residual)``: the attention output and the residual
            returned by ``run_function``.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.layer_idx = layer_idx
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if i == 0 and ctx.layer_idx != 0:
                # flash attention output is saved to the global buffer during forward
                ctx.inputs.append(None)
            elif torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        with torch.no_grad():
            q, k, v, residual = run_function(*args)
            softmax_scale = q.shape[-1] ** (-0.5)

            # lightseq version
            _, _, _, out, softmax_lse = _lightseq_forward(
                q, k, v, True, softmax_scale, comm_mode="lightseq"
            )

            # save flash attention output to global buffer
            save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
            # softmax_lse is stored last so backward can split it off again.
            tensor_inputs.append(softmax_lse)
            ctx.softmax_scale = softmax_scale

        ctx.save_for_backward(*tensor_inputs)

        return out, residual

    @staticmethod
    def backward(ctx, *args):
        """Recompute q/k/v, run the attention backward, then backprop the rest.

        Only ``args[0]`` (the gradient of the attention output) is read here.

        Returns:
            ``None`` for ``run_function``, ``layer_idx`` and
            ``preserve_rng_state``, followed by one gradient (or ``None``)
            per forward input.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors
        tensors, softmax_lse = tensors[:-1], tensors[-1]

        # Fill in inputs with appropriate saved tensors.
        # Fill the flash attention output first
        if ctx.layer_idx > 0:
            # inputs[0] should be flash attention output
            inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx - 1)
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), torch.cuda.amp.autocast(
                **ctx.gpu_autocast_kwargs
            ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                # Stop recomputation before flash attention
                # It is unnecessary to run recomputation for flash attn
                q, k, v, residual = ctx.run_function(*detached_inputs)

        # run flash attention backward first:
        # get 'dout' from auto_grad inputs
        # get 'out' from global buffer
        # get 'qkv' from the recomputed tensors
        out = get_flash_attn_out_from_global_buffer(ctx.layer_idx)
        dout = args[0]

        # lightseq version
        dq, dk, dv = _lightseq_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            ctx.softmax_scale,
            comm_mode="lightseq",
            backward_engine="flash",
        )

        # run backward for the part before flash attention
        torch.autograd.backward([q, k, v], [dq, dk, dv])

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        # write flash attention output gradients to buffer
        if ctx.layer_idx > 0:
            write_gradient_to_flash_attn_out(ctx.layer_idx - 1, detached_inputs[0].grad)

        return (None, None, None) + grads
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_end_with_flash_attention(
    function, layer_idx, *args, use_reentrant: bool = True, **kwargs
):
    """Apply ``CheckpointFunctionEndWithFlashAttention`` to ``function``.

    Args:
        function: Callable forwarded as the checkpointed ``run_function``.
        layer_idx: Index forwarded to the autograd function.
        *args: Positional inputs for ``function``.
        use_reentrant: When true, unknown keyword arguments are rejected.
        **kwargs: Only ``preserve_rng_state`` (default True) is consumed.

    Raises:
        ValueError: If extra keyword arguments remain and ``use_reentrant``
            is true.
    """
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    if use_reentrant and kwargs:
        leftover = ",".join(kwargs)
        raise ValueError("Unexpected keyword arguments: " + leftover)

    return CheckpointFunctionEndWithFlashAttention.apply(
        function, layer_idx, preserve_rng_state, *args
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointFunctionLastModule(torch.autograd.Function):
    """
    Checkpoint for the last ffn layer after flash attention.

    The first input is the flash attention output; it is read back from the
    global buffer in backward instead of being saved, and its gradient is
    written to that buffer.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run ``run_function(*args)`` without grad and record inputs for backward."""
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Non-tensor inputs are kept on ctx; tensors get a None placeholder
        # that backward fills in again.
        ctx.inputs = []
        ctx.tensor_indices = []
        kept_tensors = []

        assert torch.is_tensor(
            args[0]
        ), "assuming the first tensor is the flash attention output"
        for position, arg in enumerate(args):
            if not torch.is_tensor(arg):
                ctx.inputs.append(arg)
                continue
            ctx.inputs.append(None)
            # Position 0 is the flash attn output, already held in the global buffer.
            if position != 0:
                kept_tensors.append(arg)
                ctx.tensor_indices.append(position)

        ctx.save_for_backward(*kept_tensors)

        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *args):
        """Recompute the module with grad enabled and backprop through it.

        Returns:
            ``None`` for ``run_function`` and ``preserve_rng_state``, followed
            by one gradient (or ``None``) per forward input.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument."
            )
        # Work on a copy so the list stored on ctx is left untouched.
        inputs = list(ctx.inputs)
        saved = ctx.saved_tensors

        # inputs[0] is the flash attention output, taken from the last buffer slot.
        inputs[0] = get_flash_attn_out_from_global_buffer(-1)
        for saved_pos, input_pos in enumerate(ctx.tensor_indices):
            inputs[input_pos] = saved[saved_pos]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = (
            ctx.fwd_gpu_devices
            if ctx.preserve_rng_state and ctx.had_cuda_in_fwd
            else []
        )
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), torch.cuda.amp.autocast(
                **ctx.gpu_autocast_kwargs
            ), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i, output in enumerate(outputs):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        # write flash attention output gradients to buffer
        write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad)

        return (None, None) + grads
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs):
    """Apply ``CheckpointFunctionLastModule`` to ``function``.

    Args:
        function: Callable forwarded as the checkpointed ``run_function``.
        *args: Positional inputs for ``function``.
        use_reentrant: When true, unknown keyword arguments are rejected.
        **kwargs: Only ``preserve_rng_state`` (default True) is consumed.

    Raises:
        ValueError: If extra keyword arguments remain and ``use_reentrant``
            is true.
    """
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    if use_reentrant and kwargs:
        leftover = ",".join(kwargs)
        raise ValueError("Unexpected keyword arguments: " + leftover)

    return CheckpointFunctionLastModule.apply(function, preserve_rng_state, *args)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    compute_attn_only: Optional[bool] = False,
    compute_ffn_only: Optional[bool] = False,
    residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Run one half of a decoder layer: the part before or after attention.

    Exactly which half runs is chosen by the two flags; attention itself is
    not computed here.

    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            when `compute_attn_only` is set; when `compute_ffn_only` is set it is
            rearranged with the pattern `"b h s d -> b s (h d)"` before `o_proj`.
        attention_mask (`torch.FloatTensor`, *optional*): accepted for interface
            compatibility; not read by this function.
        position_ids: passed to the rotary embedding in the attention half.
        past_key_value: must be `None`; caching is not supported.
        output_attentions (`bool`, *optional*): must be falsy; not supported.
        use_cache (`bool`, *optional*): must be falsy; not supported.
        compute_attn_only: run input layernorm, the q/k/v projections and rotary
            embedding, returning `(query, key, value, residual)`.
        compute_ffn_only: run `o_proj`, the residual add and the MLP block,
            returning `(hidden_states,)`.
        residual: residual tensor from the attention half; required when
            `compute_ffn_only` is set.

    Raises:
        AssertionError: if neither flag is set, or an unsupported option is used.
    """
    assert compute_ffn_only or compute_attn_only

    if compute_attn_only:
        residual = hidden_states

        if residual.requires_grad:
            # register a hook to add the gradient of residual
            # from next checkpoint layer when doing recomputation
            hook = residual.register_hook(load_and_add_res_grad_hook)
            global_hooks.append(hook)

        hidden_states = self.input_layernorm(hidden_states)

        # Flash Attention
        bsz, q_len, _ = hidden_states.size()
        attn = self.self_attn
        # old transformers versions don't support num_key_value_heads; fall
        # back to num_heads only in that case instead of swallowing every error.
        num_kv_heads = getattr(attn, "num_key_value_heads", attn.num_heads)

        query_states = (
            attn.q_proj(hidden_states)
            .view(bsz, q_len, attn.num_heads, attn.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            attn.k_proj(hidden_states)
            .view(bsz, q_len, num_kv_heads, attn.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            attn.v_proj(hidden_states)
            .view(bsz, q_len, num_kv_heads, attn.head_dim)
            .transpose(1, 2)
        )

        assert past_key_value is None, "past_key_value is not supported"

        cos, sin = attn.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        # [bsz, nh, t, hd]
        assert not output_attentions, "output_attentions is not supported"
        assert not use_cache, "use_cache is not supported"
        return (
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            residual,
        )

    elif compute_ffn_only:
        hidden_states = self.self_attn.o_proj(
            rearrange(hidden_states, "b h s d -> b s (h d)")
        )
        # Need to add residual here to make sure checkpoint is right after attention
        if residual.requires_grad:
            # save the gradient of residual to the local buffer
            # collect the hooks which should be removed after backward to avoid memory leak
            hook = residual.register_hook(save_res_grad_hook)
            global_hooks.append(hook)

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

    else:
        # Only reachable when assertions are disabled.
        raise AttributeError

    return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
):
    """Model forward using flash-attention friendly gradient checkpointing.

    When ``self.gradient_checkpointing and self.training`` holds, the layers
    are run as ``len(self.layers) + 1`` checkpointed segments: the attention
    half of layer 0, then for each following index the ffn half of the
    previous layer fused with the attention half of the current one, and
    finally the ffn half of the last layer. Otherwise each decoder layer is
    called directly.

    The incoming ``attention_mask`` is discarded (set to ``None``).

    NOTE(review): the non-checkpointing branch calls the decoder layer
    without ``compute_attn_only``/``compute_ffn_only``; with this file's
    ``llama_layer_forward`` installed that call reaches its
    ``assert compute_ffn_only or compute_attn_only`` -- confirm this branch
    is intentionally unsupported.

    Raises:
        AssertionError: If ``cache_position`` is given.
        ValueError: If both or neither of ``input_ids``/``inputs_embeds``
            are given.

    Returns:
        ``BaseModelOutputWithPast``, or a tuple of its non-``None`` fields
        when ``return_dict`` is false.
    """
    # ``logger`` is not defined at module level in this file, so bind one
    # here; previously the NameError was silently swallowed and the warnings
    # below were never emitted.
    logger = transformers.utils.logging.get_logger(__name__)

    assert cache_position is None, "cache_position is not supported"
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
        )
    elif input_ids is not None:
        _, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        _, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError(
            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
        )

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    attention_mask = None

    hidden_states = inputs_embeds

    use_fast_checkpointing = self.gradient_checkpointing and self.training

    if use_fast_checkpointing:
        logger.warning_once("***** Using fast gradient checkpointing... *****")
        # initialize the global buffer
        init_flash_attn_buffers(len(self.layers))

        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    # apply flash-attention friendly gradient checkpointing
    if use_fast_checkpointing:
        # The factories below are loop-invariant. Their closures read
        # ``past_key_value`` and ``output_attentions`` from this frame at
        # call time, which includes calls made during backward.
        def forward_first_attn_module(module):
            def custom_forward(*inputs):
                hidden_states, attention_mask, position_ids, _ = inputs
                # None for past_key_value
                return module(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    compute_attn_only=True,
                )

            return custom_forward

        def forward_ffn_attn_layer(module1, module2):
            def custom_forward(*inputs):
                hidden_states, attention_mask, position_ids, residual = inputs
                # None for past_key_value
                layer_outputs = module1(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    compute_ffn_only=True,
                    residual=residual,
                )
                hidden_states = layer_outputs[0]
                return module2(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    compute_attn_only=True,
                )

            return custom_forward

        def forward_last_ffn_module(module):
            def custom_forward(*inputs):
                hidden_states, attention_mask, position_ids, residual = inputs
                # None for past_key_value
                return module(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    compute_ffn_only=True,
                    residual=residual,
                )

            return custom_forward

        num_layers = len(self.layers)
        for idx in range(num_layers + 1):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if idx == 0:
                layer_outputs = checkpoint_end_with_flash_attention(
                    forward_first_attn_module(self.layers[0]),
                    idx,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
                hidden_states, residual = layer_outputs[0], layer_outputs[-1]
            elif idx == num_layers:
                layer_outputs = checkpoint_last_module(
                    forward_last_ffn_module(self.layers[-1]),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    residual,
                )
                hidden_states = layer_outputs[0]
            else:
                layer_outputs = checkpoint_end_with_flash_attention(
                    forward_ffn_attn_layer(self.layers[idx - 1], self.layers[idx]),
                    idx,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    residual,
                )
                hidden_states, residual = layer_outputs[0], layer_outputs[-1]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
    else:
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def apply_dist_flash_attn_monkey_patch_llama():
    """Install the distributed flash-attention forwards on the Llama classes.

    Calls ``initialize_distributed()`` first, then replaces
    ``LlamaModel.forward`` with this module's ``forward`` and
    ``LlamaDecoderLayer.forward`` with ``llama_layer_forward``.
    """
    initialize_distributed()
    llama_module = transformers.models.llama.modeling_llama
    llama_module.LlamaModel.forward = forward
    llama_module.LlamaDecoderLayer.forward = llama_layer_forward
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
def extract_local(value, rank, world_size, device, dim=1):
    """Return the ``rank``-th shard of ``value``, moved to ``device``.

    ``value`` is split into ``world_size`` chunks along ``dim`` using its
    ``chunk`` method, and the chunk at index ``rank`` is selected.
    """
    shards = value.chunk(world_size, dim=dim)
    shard = shards[rank]
    return shard.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dist_flash_attn_inputs(
    input_ids, position_ids, target_ids, rank, world_size, device
):
    """Shard the model inputs for this rank via ``extract_local``.

    ``input_ids`` and ``position_ids`` are always sharded; ``target_ids`` is
    sharded only when it is not ``None``.

    Returns:
        dict with keys ``"local_input_ids"``, ``"local_position_ids"`` and
        ``"local_target_ids"`` (the last is ``None`` when no targets were given).
    """
    shard_args = (rank, world_size, device)

    sharded_inputs = extract_local(input_ids, *shard_args)
    sharded_positions = extract_local(position_ids, *shard_args)
    sharded_targets = (
        extract_local(target_ids, *shard_args) if target_ids is not None else None
    )

    return {
        "local_input_ids": sharded_inputs,
        "local_position_ids": sharded_positions,
        "local_target_ids": sharded_targets,
    }
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
# yunchang is an optional dependency. If importing it or constructing the
# attention object fails, fall back to ``None`` so this module stays importable.
try:
    from yunchang.ulysses import UlyssesAttention

    ulysses_attn = UlyssesAttention()
except Exception:  # deliberately not bare: KeyboardInterrupt/SystemExit must propagate
    print(
        "If you want to use the UlyssesAttention class, please install the yunchang package."
    )
    ulysses_attn = None
|
||||||
|
|
||||||
|
|
||||||
|
def new_flash_attn_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    use_sliding_windows=False,
):
    """Replacement ``_flash_attention_forward`` that delegates to ``ulysses_attn``.

    The causal flag is ``self.is_causal``, except that when
    ``self._flash_attn_uses_top_left_mask`` is set it is additionally required
    that ``query_length != 1``.

    Raises:
        AssertionError: if ``attention_mask`` is not ``None``, the computed
            causal flag is not ``True``, or ``use_sliding_windows`` is not
            ``False``.
        RuntimeError: if ``ulysses_attn`` is ``None`` (the module-level
            fallback used when the yunchang setup failed).
    """
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and query_length != 1

    # Only unmasked, causal, non-sliding-window attention is handled here.
    assert attention_mask is None
    assert causal is True
    assert use_sliding_windows is False

    # Fail with an actionable message instead of "'NoneType' is not callable".
    if ulysses_attn is None:
        raise RuntimeError(
            "UlyssesAttention is not available; please install the yunchang package."
        )

    attn_output = ulysses_attn(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale,
        causal=causal,
    )

    return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
def new_decoder_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Decoder-layer forward used by the Ulysses attention monkey patch.

    Applies input layernorm -> self-attention -> residual add, then
    post-attention layernorm -> MLP -> residual add.

    Returns:
        A tuple starting with the new hidden states, followed by the attention
        weights when ``output_attentions`` is truthy and by the present
        key/value when ``use_cache`` is truthy.

    Raises:
        AssertionError: if ``self.self_attn`` is neither a
            ``LlamaFlashAttention2`` nor a ``MistralFlashAttention2``.
    """
    # Kept as two isinstance calls so the Mistral class is only looked up when
    # the Llama check fails.
    assert isinstance(
        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
    ) or isinstance(
        self.self_attn,
        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
    ), "Please toggle on the Flash Attention 2 implementation when using Ulysses attention monkey patch."

    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def apply_ulysses_attn_monkey_patch_llama():
    """Patch the Llama classes to use the forwards defined in this module.

    ``LlamaFlashAttention2._flash_attention_forward`` becomes
    ``new_flash_attn_forward`` and ``LlamaDecoderLayer.forward`` becomes
    ``new_decoder_forward``.
    """
    llama_module = transformers.models.llama.modeling_llama
    llama_module.LlamaFlashAttention2._flash_attention_forward = new_flash_attn_forward
    llama_module.LlamaDecoderLayer.forward = new_decoder_forward
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def extract_local(value, rank, world_size, device, dim=1):
    """Return this rank's contiguous slice of ``value`` along ``dim``.

    The axis is cut into ``world_size`` pieces of
    ``value.shape[dim] // world_size`` elements each; any remainder at the end
    of the axis is not assigned to any rank.

    Args:
        value: tensor to shard.
        rank: index of the piece to return.
        world_size: number of pieces.
        device: device the returned slice is moved to.
        dim: axis to shard along (defaults to 1).

    Returns:
        The selected slice, moved to ``device``.
    """
    sub_seq_length = value.shape[dim] // world_size
    sub_seq_start = rank * sub_seq_length

    # Slice along ``dim`` itself rather than a fixed axis, so that non-default
    # ``dim`` values shard the axis whose size was measured above.
    local_value = value.narrow(dim, sub_seq_start, sub_seq_length)

    return local_value.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_ulysses_attn_inputs(
    input_ids, position_ids, target_ids, rank, world_size, device
):
    """Shard the model inputs for this rank via ``extract_local``.

    ``target_ids`` is optional; when it is ``None`` the corresponding entry
    in the result is ``None`` as well.

    Returns:
        dict with keys ``"local_input_ids"``, ``"local_position_ids"`` and
        ``"local_target_ids"``.
    """
    shard_args = (rank, world_size, device)

    sharded_inputs = extract_local(input_ids, *shard_args)
    sharded_positions = extract_local(position_ids, *shard_args)

    sharded_targets = None
    if target_ids is not None:
        sharded_targets = extract_local(target_ids, *shard_args)

    return {
        "local_input_ids": sharded_inputs,
        "local_position_ids": sharded_positions,
        "local_target_ids": sharded_targets,
    }
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
|
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non blocking calls.

    Forward copies ``hidden_states`` to the CPU and runs ``forward_function``
    under ``torch.no_grad()``. Backward moves the saved copy back, re-runs
    ``forward_function`` with gradients enabled and backpropagates ``dY``
    through the recomputed output.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        # Remember where the activations live so that backward restores them to
        # the same device instead of whichever CUDA device happens to be current.
        ctx.device = hidden_states.device
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args

        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to(ctx.device, non_blocking=True).detach()
        # Fresh leaf tensor so the recomputation below accumulates into ``.grad``.
        hidden_states.requires_grad = True
        with torch.enable_grad():
            # ``forward_function`` is expected to return a 1-tuple here.
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        # One gradient slot per forward input: the function, hidden_states, *args.
        return (
            None,
            hidden_states.grad,
        ) + (
            None,
        ) * len(ctx.args)
|
||||||
|
|
||||||
|
|
||||||
|
def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    """Enable gradient checkpointing with the offloading checkpointer.

    Replacement for ``PreTrainedModel.gradient_checkpointing_enable`` that
    passes ``Unsloth_Offloaded_Gradient_Checkpointer.apply`` as the
    checkpointing function.

    Args:
        gradient_checkpointing_kwargs: must be ``None``; custom kwargs are not
            supported by this replacement.

    Raises:
        AssertionError: if ``gradient_checkpointing_kwargs`` is not ``None``.
        ValueError: if the model does not support gradient checkpointing.
        NotImplementedError: if ``_set_gradient_checkpointing`` still takes a
            ``value`` parameter (the old format).
    """
    assert gradient_checkpointing_kwargs is None
    if not self.supports_gradient_checkpointing:
        raise ValueError(
            f"{self.__class__.__name__} does not support gradient checkpointing."
        )

    gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
    # For old GC format (transformers < 4.35.0) for models that live on the Hub
    # we will fall back to the overwritten `_set_gradient_checkpointing` method
    _is_using_old_format = (
        "value" in inspect.signature(self._set_gradient_checkpointing).parameters
    )

    if not _is_using_old_format:
        self._set_gradient_checkpointing(
            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
        )
    else:
        raise NotImplementedError()

    if getattr(self, "_hf_peft_config_loaded", False):
        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
        # the gradients to make sure the gradient flows.
        self.enable_input_require_grads()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
    """Replace ``PreTrainedModel.gradient_checkpointing_enable`` with
    ``new_gradient_checkpointing_enable``."""
    pretrained_cls = transformers.modeling_utils.PreTrainedModel
    pretrained_cls.gradient_checkpointing_enable = new_gradient_checkpointing_enable
|
||||||
114
src/axolotl/integrations/easy_context/usp/monkey_patch.py
Normal file
114
src/axolotl/integrations/easy_context/usp/monkey_patch.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
# yunchang is an optional dependency. If importing it or constructing the
# attention object fails, fall back to ``None`` so this module stays importable.
try:
    from yunchang import LongContextAttention, set_seq_parallel_pg

    usp_attn = LongContextAttention(ring_impl_type="zigzag")
except Exception:  # deliberately not bare: KeyboardInterrupt/SystemExit must propagate
    print(
        "If you want to use the LongContextAttention class, please install the yunchang package."
    )
    usp_attn = None
|
||||||
|
|
||||||
|
|
||||||
|
def new_flash_attn_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    use_sliding_windows=False,
):
    """Replacement ``_flash_attention_forward`` that delegates to ``usp_attn``.

    The causal flag is ``self.is_causal``, except that when
    ``self._flash_attn_uses_top_left_mask`` is set it is additionally required
    that ``query_length != 1``.

    Raises:
        AssertionError: if ``attention_mask`` is not ``None``, the computed
            causal flag is not ``True``, or ``use_sliding_windows`` is not
            ``False``.
        RuntimeError: if ``usp_attn`` is ``None`` (the module-level fallback
            used when the yunchang setup failed).
    """
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and query_length != 1

    # Only unmasked, causal, non-sliding-window attention is handled here.
    assert attention_mask is None
    assert causal is True
    assert use_sliding_windows is False

    # Fail with an actionable message instead of "'NoneType' is not callable".
    if usp_attn is None:
        raise RuntimeError(
            "LongContextAttention is not available; please install the yunchang package."
        )

    attn_output = usp_attn(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale,
        causal=causal,
    )

    return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
def new_decoder_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Decoder-layer forward used by the USP attention monkey patch.

    Applies input layernorm -> self-attention -> residual add, then
    post-attention layernorm -> MLP -> residual add.

    Returns:
        A tuple starting with the new hidden states, followed by the attention
        weights when ``output_attentions`` is truthy and by the present
        key/value when ``use_cache`` is truthy.

    Raises:
        AssertionError: if ``self.self_attn`` is neither a
            ``LlamaFlashAttention2`` nor a ``MistralFlashAttention2``.
    """
    # Kept as two isinstance calls so the Mistral class is only looked up when
    # the Llama check fails.
    assert isinstance(
        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
    ) or isinstance(
        self.self_attn,
        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
    ), "Please toggle on the Flash Attention 2 implementation when using USP attention monkey patch."

    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def apply_usp_attn_monkey_patch_llama():
    """Patch the Llama classes to use the forwards defined in this module.

    ``LlamaFlashAttention2._flash_attention_forward`` becomes
    ``new_flash_attn_forward`` and ``LlamaDecoderLayer.forward`` becomes
    ``new_decoder_forward``.
    """
    llama_module = transformers.models.llama.modeling_llama
    llama_module.LlamaFlashAttention2._flash_attention_forward = new_flash_attn_forward
    llama_module.LlamaDecoderLayer.forward = new_decoder_forward
|
||||||
58
src/axolotl/integrations/easy_context/usp/prepare_inputs.py
Normal file
58
src/axolotl/integrations/easy_context/usp/prepare_inputs.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
from yunchang import set_seq_parallel_pg
|
||||||
|
from yunchang.comm import zigzag_extract_local
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_usp_attn_inputs(
    input_ids,
    position_ids,
    target_ids,
    rank,
    world_size,
    device,
    ring_degree,
    ulysses_degree,
):
    """
    Prepare input for USP attention.

    USP: A Unified Sequence Parallelism Approach for Long Context Generative AI
    https://arxiv.org/abs/2405.07719

    Calls ``set_seq_parallel_pg`` with the given degrees, then shards each
    input with ``zigzag_extract_local`` and moves the shard to ``device``.
    ``target_ids`` is optional.

    Returns:
        dict with keys ``"local_input_ids"``, ``"local_position_ids"`` and
        ``"local_target_ids"`` (the last is ``None`` when ``target_ids`` is
        ``None``).
    """
    set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)

    local_input_ids = zigzag_extract_local(
        input_ids,
        rank,
        world_size,
        ring_degree,
        ulysses_degree,
    ).to(device)

    # truncate position_ids to the same size as input_ids
    # NOTE(review): the cut uses the *local* input length and the result is
    # sharded again below; confirm the final length matches local_input_ids.
    position_ids = position_ids[:, : local_input_ids.shape[1]]

    local_position_ids = zigzag_extract_local(
        position_ids,
        rank,
        world_size,
        ring_degree,
        ulysses_degree,
    ).to(device)

    if target_ids is not None:
        local_target_ids = zigzag_extract_local(
            target_ids,
            rank,
            world_size,
            ring_degree,
            ulysses_degree,
        ).to(device)
    else:
        local_target_ids = None
    return {
        "local_input_ids": local_input_ids,
        "local_position_ids": local_position_ids,
        "local_target_ids": local_target_ids,
    }
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import transformers
|
||||||
|
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
|
||||||
|
|
||||||
|
|
||||||
|
def new_flash_attn_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    use_sliding_windows=False,
):
    """Replacement ``_flash_attention_forward`` backed by
    ``zigzag_ring_flash_attn_func``.

    Asserts that no attention mask is given, that the computed causal flag is
    ``True`` and that sliding windows are disabled, then forwards the
    query/key/value states, dropout and softmax scale to the ring kernel.
    """
    if self._flash_attn_uses_top_left_mask:
        causal = self.is_causal and query_length != 1
    else:
        causal = self.is_causal

    # Only unmasked, causal, non-sliding-window attention is handled here.
    assert attention_mask is None
    assert causal is True
    assert use_sliding_windows is False

    return zigzag_ring_flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale,
        causal=causal,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def new_decoder_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Decoder-layer forward used by the zigzag ring attention monkey patch.

    Runs the attention sub-block (input layernorm, self-attention, residual
    add) followed by the feed-forward sub-block (post-attention layernorm,
    MLP, residual add).

    Returns:
        ``(hidden_states,)`` extended with the attention weights when
        ``output_attentions`` is truthy and with the present key/value when
        ``use_cache`` is truthy, in that order.
    """
    assert isinstance(
        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
    ) or isinstance(
        self.self_attn,
        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
    ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."

    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    # Attention sub-block
    attn_input = hidden_states
    normed_input = self.input_layernorm(attn_input)
    attn_output, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=normed_input,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = attn_input + attn_output

    # Feed-forward sub-block
    mlp_input = hidden_states
    mlp_output = self.mlp(self.post_attention_layernorm(mlp_input))
    hidden_states = mlp_input + mlp_output

    extras = []
    if output_attentions:
        extras.append(self_attn_weights)
    if use_cache:
        extras.append(present_key_value)

    return (hidden_states, *extras)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_zigzag_ring_attn_monkey_patch_llama():
    """Patch the Llama classes to use the forwards defined in this module.

    ``LlamaFlashAttention2._flash_attention_forward`` becomes
    ``new_flash_attn_forward`` and ``LlamaDecoderLayer.forward`` becomes
    ``new_decoder_forward``.
    """
    llama_module = transformers.models.llama.modeling_llama
    llama_module.LlamaFlashAttention2._flash_attention_forward = new_flash_attn_forward
    llama_module.LlamaDecoderLayer.forward = new_decoder_forward
|
||||||
|
|
||||||
|
|
||||||
|
def apply_zigzag_ring_attn_monkey_patch_mistral():
    """Patch the Mistral classes to use the forwards defined in this module.

    ``MistralFlashAttention2._flash_attention_forward`` becomes
    ``new_flash_attn_forward`` and ``MistralDecoderLayer.forward`` becomes
    ``new_decoder_forward``.
    """
    mistral_module = transformers.models.mistral.modeling_mistral
    mistral_module.MistralFlashAttention2._flash_attention_forward = new_flash_attn_forward
    mistral_module.MistralDecoderLayer.forward = new_decoder_forward
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def extract_local(value, rank, world_size, device, dim=1):
    """Return this rank's zigzag shard of ``value``, moved to ``device``.

    ``value`` is split into ``2 * world_size`` chunks along ``dim``. The
    shard is chunk ``rank`` concatenated (along ``dim``) with chunk
    ``2 * world_size - rank - 1``, i.e. one chunk from the front and its
    mirror from the back.
    """
    total_chunks = 2 * world_size
    pieces = value.chunk(total_chunks, dim=dim)
    front_piece = pieces[rank]
    back_piece = pieces[total_chunks - rank - 1]
    return torch.cat([front_piece, back_piece], dim=dim).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_zigzag_ring_attn_inputs(
    input_ids, position_ids, target_ids, rank, world_size, device
):
    """Shard the model inputs for this rank via ``extract_local``.

    ``input_ids`` and ``position_ids`` are always sharded; ``target_ids`` is
    sharded only when it is not ``None``.

    Returns:
        dict with keys ``"local_input_ids"``, ``"local_position_ids"`` and
        ``"local_target_ids"`` (the last is ``None`` when no targets were given).
    """
    shard_args = (rank, world_size, device)

    sharded_inputs = extract_local(input_ids, *shard_args)
    sharded_positions = extract_local(position_ids, *shard_args)
    sharded_targets = (
        None if target_ids is None else extract_local(target_ids, *shard_args)
    )

    return {
        "local_input_ids": sharded_inputs,
        "local_position_ids": sharded_positions,
        "local_target_ids": sharded_targets,
    }
|
||||||
123
src/axolotl/monkeypatch/sequence_parallel.py
Normal file
123
src/axolotl/monkeypatch/sequence_parallel.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""
|
||||||
|
Utilities for sequence parallelism implementation.
|
||||||
|
|
||||||
|
Modified from:
|
||||||
|
https://github.com/Qihoo360/360-LLaMA-Factory/blob/f295a5760cceebe069fb5b975813d2c945598acb/src/llamafactory/model/model_utils/sequence_parallel.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
import transformers
|
||||||
|
import transformers.modeling_attn_mask_utils
|
||||||
|
from ring_flash_attn import (
|
||||||
|
ring_flash_attn_func,
|
||||||
|
stripe_flash_attn_func,
|
||||||
|
zigzag_ring_flash_attn_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ring_flash_attn_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    dropout=0,
    sliding_window=None,
    is_causal=True,
    group=None,
    **kwargs,
):
    """Flash-attention forward that dispatches to ``ring_flash_attn_func``.

    Only the query/key/value states, ``dropout``, ``is_causal`` and ``group``
    are forwarded; ``attention_mask``, ``q_len``, ``sliding_window`` and any
    extra keyword arguments are accepted but not used.
    """
    return ring_flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        causal=is_causal,
        group=group,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def zigzag_flash_attn_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    dropout=0,
    sliding_window=None,
    is_causal=True,
    group=None,
    **kwargs,
):
    """Flash-attention forward that dispatches to ``zigzag_ring_flash_attn_func``.

    Only the query/key/value states, ``dropout``, ``is_causal`` and ``group``
    are forwarded; ``attention_mask``, ``q_len``, ``sliding_window`` and any
    extra keyword arguments are accepted but not used.
    """
    return zigzag_ring_flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        causal=is_causal,
        group=group,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def stripe_flash_attn_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    dropout=0,
    sliding_window=None,
    is_causal=True,
    group=None,
    **kwargs,
):
    """Flash-attention forward that dispatches to ``stripe_flash_attn_func``.

    Only the query/key/value states, ``dropout``, ``is_causal`` and ``group``
    are forwarded; ``attention_mask``, ``q_len``, ``sliding_window`` and any
    extra keyword arguments are accepted but not used.
    """
    return stripe_flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        causal=is_causal,
        group=group,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def init_sp_group(sp_size):
    """Create the sequence-parallel process groups and return this rank's group.

    Ranks are partitioned into consecutive blocks of ``sp_size``; one group is
    created per block with ``dist.new_group``, and the group whose block
    contains the current rank is returned.

    Raises:
        AssertionError: if ``torch.distributed`` is not initialized or the
            world size is not a multiple of ``sp_size``.
    """
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    assert (
        world_size % sp_size == 0
    ), "Total number of GPUs must be a multiple of sequence_parallel_size."

    rank_blocks = []
    for block_idx in range(world_size // sp_size):
        first_rank = block_idx * sp_size
        rank_blocks.append(list(range(first_rank, first_rank + sp_size)))

    # Every block gets a group, not only the one this rank belongs to.
    sp_groups = []
    for block in rank_blocks:
        sp_groups.append(dist.new_group(block))

    return sp_groups[dist.get_rank() // sp_size]
|
||||||
|
|
||||||
|
|
||||||
|
def apply_sequence_parallel(cfg):
    """Set up sequence parallelism according to ``cfg``.

    When ``cfg.sequence_parallel_size`` is 1 nothing happens and ``None`` is
    returned. Otherwise the sequence-parallel group is initialised, the
    forward matching ``cfg.sequence_parallel_mode`` (``"ring"``,
    ``"zigzag-ring"`` or ``"stripe"``) is bound to that group and assigned to
    ``transformers.modeling_flash_attention_utils._flash_attention_forward``.

    Returns:
        The process group for this rank, or ``None`` when disabled.

    Raises:
        NotImplementedError: for any other ``sequence_parallel_mode``.
    """
    if cfg.sequence_parallel_size == 1:
        return None  # no sequence parallelism

    # init sequence-parallel groups here
    group_this = init_sp_group(cfg.sequence_parallel_size)

    if cfg.sequence_parallel_mode == "ring":
        attn_impl = ring_flash_attn_forward
    elif cfg.sequence_parallel_mode == "zigzag-ring":
        attn_impl = zigzag_flash_attn_forward
    elif cfg.sequence_parallel_mode == "stripe":
        attn_impl = stripe_flash_attn_forward
    else:
        raise NotImplementedError(
            "Other sequence parallel modes are to be implemented."
        )

    # monkey patching
    flash_utils = transformers.modeling_flash_attention_utils
    flash_utils._flash_attention_forward = partial(attn_impl, group=group_this)

    return group_this
|
||||||
@@ -547,6 +547,20 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.sequence_parallel_size > 1:
|
||||||
|
from axolotl.integrations.easy_context import (
|
||||||
|
apply_seq_parallel_monkey_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
method = self.cfg.sequence_parallel_mode
|
||||||
|
model_type = self.cfg.model_type
|
||||||
|
|
||||||
|
# Apply the monkey patch
|
||||||
|
apply_seq_parallel_monkey_patch(method, model_type)
|
||||||
|
|
||||||
|
# Ensure flash attention is enabled when loading the model
|
||||||
|
self.cfg.attn_implementation = "flash_attention_2"
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
|
|||||||
Reference in New Issue
Block a user