cleanup
This commit is contained in:
@@ -1,96 +0,0 @@
|
||||
import logging
|
||||
|
||||
from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
|
||||
from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs
|
||||
from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
|
||||
from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs
|
||||
from .usp.monkey_patch import apply_usp_attn_monkey_patch_llama
|
||||
from .usp.prepare_inputs import prepare_usp_attn_inputs
|
||||
from .zigzag_ring_attn.monkey_patch import (
|
||||
apply_zigzag_ring_attn_monkey_patch_llama,
|
||||
apply_zigzag_ring_attn_monkey_patch_mistral,
|
||||
)
|
||||
from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prepare_seq_parallel_inputs(
|
||||
seq_algo,
|
||||
input_ids,
|
||||
position_ids,
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if seq_algo == "zigzag_ring_attn":
|
||||
return prepare_zigzag_ring_attn_inputs(
|
||||
input_ids, position_ids, target_ids, rank, world_size, device
|
||||
)
|
||||
elif seq_algo == "dist_flash_attn":
|
||||
return prepare_dist_flash_attn_inputs(
|
||||
input_ids, position_ids, target_ids, rank, world_size, device
|
||||
)
|
||||
elif seq_algo == "ulysses_attn":
|
||||
return prepare_ulysses_attn_inputs(
|
||||
input_ids, position_ids, target_ids, rank, world_size, device
|
||||
)
|
||||
elif seq_algo == "usp_attn":
|
||||
ring_degree = kwargs.get("ring_degree", 1)
|
||||
ulysses_degree = world_size // ring_degree
|
||||
logger.info(
|
||||
f"Applying USP: Ring degree: {ring_degree}, Ulysses degree: {ulysses_degree}"
|
||||
)
|
||||
return prepare_usp_attn_inputs(
|
||||
input_ids,
|
||||
position_ids,
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
ulysses_degree,
|
||||
ring_degree,
|
||||
)
|
||||
elif seq_algo == "data_parallel":
|
||||
return {
|
||||
"local_input_ids": input_ids.to(device),
|
||||
"local_position_ids": position_ids.to(device),
|
||||
"local_target_ids": target_ids.to(device),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid seq_algo: {seq_algo}")
|
||||
|
||||
|
||||
def apply_seq_parallel_monkey_patch(seq_algo, model):
|
||||
assert seq_algo in [
|
||||
"zigzag_ring_attn",
|
||||
"dist_flash_attn",
|
||||
"ulysses_attn",
|
||||
"data_parallel",
|
||||
"usp_attn",
|
||||
], f"Invalid seq_algo: {seq_algo}"
|
||||
assert model in ["llama", "mistral"], f"Invalid model: {model}"
|
||||
if seq_algo == "data_parallel":
|
||||
return
|
||||
elif seq_algo == "zigzag_ring_attn" and model == "llama":
|
||||
apply_zigzag_ring_attn_monkey_patch_llama()
|
||||
elif seq_algo == "zigzag_ring_attn" and model == "mistral":
|
||||
apply_zigzag_ring_attn_monkey_patch_mistral()
|
||||
elif seq_algo == "dist_flash_attn" and model == "llama":
|
||||
apply_dist_flash_attn_monkey_patch_llama()
|
||||
elif seq_algo == "ulysses_attn" and model == "llama":
|
||||
apply_ulysses_attn_monkey_patch_llama()
|
||||
elif seq_algo == "usp_attn" and model == "llama":
|
||||
apply_usp_attn_monkey_patch_llama()
|
||||
else:
|
||||
raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}")
|
||||
|
||||
|
||||
def prepare_dataloader(seq_algo, dataloader, accelerator):
|
||||
if seq_algo == "data_parallel":
|
||||
return accelerator.prepare(dataloader)
|
||||
else:
|
||||
return dataloader
|
||||
@@ -1,11 +0,0 @@
|
||||
# LightSeq
|
||||
Taken from https://github.com/RulinShao/LightSeq. All credits to the authors.
|
||||
|
||||
```
|
||||
@article{li2023lightseq,
|
||||
title={LIGHTSEQ: SEQUENCE LEVEL PARALLELISM FOR DISTRIBUTED TRAINING OF LONG CONTEXT TRANS},
|
||||
author={Li, Dacheng and Shao, Rulin and Xie𝑠, Anze and Xing𝑐𝑚, Eric P and Gonzalez𝑏, Joseph E and Stoica𝑏, Ion and Ma𝑢, Xuezhe and Zhang𝑠, Hao},
|
||||
journal={arXiv preprint arXiv:2310.03294},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
@@ -1,776 +0,0 @@
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import P2POp, batch_isend_irecv, irecv, isend
|
||||
|
||||
# Sequence parallel group that the current rank belongs to.
|
||||
_SEQUENCE_PARALLEL_GROUP = None
|
||||
|
||||
# These values enable us to change the sequence parallel sizes on the fly.
|
||||
_SEQUENCE_PARALLEL_SIZE = None
|
||||
_SEQUENCE_PARALLEL_RANK = None
|
||||
|
||||
# Global buffer for P2P
|
||||
_PEER_Q = None
|
||||
_PEER_K = None
|
||||
_PEER_V = None
|
||||
_PEER_M = None
|
||||
_PEER_L = None
|
||||
_PEER_O = None
|
||||
_PEER_Q_BWD = None
|
||||
_PEER_K_BWD = None
|
||||
_PEER_V_BWD = None
|
||||
_PEER_O_BWD = None
|
||||
|
||||
_DELTA_DQ = None
|
||||
_PEER_L = None
|
||||
_DELTA_DK = None
|
||||
_DELTA_DV = None
|
||||
_DK_DELTA_FROM_PEER = None
|
||||
_DV_DELTA_FROM_PEER = None
|
||||
_PEER_DO = None
|
||||
|
||||
|
||||
_fwd_send_volume = 0
|
||||
_fwd_recv_volume = 0
|
||||
_bwd_send_volume = 0
|
||||
_bwd_recv_volume = 0
|
||||
|
||||
|
||||
def initialize_distributed():
|
||||
if dist.is_initialized():
|
||||
if dist.get_rank() == 0:
|
||||
print(
|
||||
"torch distributed is already initialized, "
|
||||
"skipping initialization ...",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
if int(os.environ["RANK"]) == 0:
|
||||
print("Initializing Torch distributed.")
|
||||
dist.init_process_group(backend="nccl")
|
||||
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
|
||||
global_world_size = dist.get_world_size()
|
||||
torch.cuda.set_device(dist.get_rank() % local_world_size)
|
||||
|
||||
_initialize_sequence_parallel()
|
||||
|
||||
|
||||
# create_nccl_communicators()
|
||||
|
||||
|
||||
def _initialize_sequence_parallel(sequence_parallel_size=None):
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert (
|
||||
sequence_parallel_size is None
|
||||
), "Multiple sequence parallel group not implemented."
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
|
||||
if sequence_parallel_size is None:
|
||||
sequence_parallel_size = world_size
|
||||
else:
|
||||
assert world_size % sequence_parallel_size == 0
|
||||
num_sequence_parallel_groups: int = world_size // sequence_parallel_size
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
# Build the sequence parallel groups.
|
||||
global _SEQUENCE_PARALLEL_GROUP
|
||||
global _SEQUENCE_PARALLEL_RANK
|
||||
global _SEQUENCE_PARALLEL_SIZE
|
||||
|
||||
assert (
|
||||
_SEQUENCE_PARALLEL_GROUP is None
|
||||
), "sequence parallel group is already initialized"
|
||||
for i in range(num_sequence_parallel_groups):
|
||||
ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_SEQUENCE_PARALLEL_GROUP = group
|
||||
_SEQUENCE_PARALLEL_RANK = ranks.index(rank)
|
||||
_SEQUENCE_PARALLEL_SIZE = len(ranks)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print("************ Finish sequence pralell group Initialization. ***********")
|
||||
# _set_global_memory_buffer()
|
||||
|
||||
|
||||
def maybe_get_set_global_memory_buffer(q, k, v, m, l, o):
|
||||
global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
|
||||
if _PEER_Q is None:
|
||||
try:
|
||||
if get_sequence_parallel_rank() == 0:
|
||||
print("Initializing global memoery buffer.")
|
||||
except:
|
||||
print("Initializing global memoery buffer.")
|
||||
_PEER_Q = [torch.empty_like(q) for _ in range(2)]
|
||||
_PEER_K = [torch.empty_like(k) for _ in range(2)]
|
||||
_PEER_V = [torch.empty_like(v) for _ in range(2)]
|
||||
_PEER_M = [torch.empty_like(m) for _ in range(2)]
|
||||
_PEER_L = [torch.empty_like(l) for _ in range(2)]
|
||||
_PEER_O = [torch.empty_like(o) for _ in range(2)]
|
||||
|
||||
return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O
|
||||
|
||||
|
||||
def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do):
|
||||
global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO
|
||||
if _DELTA_DQ is None:
|
||||
try:
|
||||
if get_sequence_parallel_rank() == 0:
|
||||
print("Initializing global memoery buffer for backward.")
|
||||
except:
|
||||
print("Initializing global memoery buffer for backward.")
|
||||
_DELTA_DQ = [torch.empty_like(dq) for _ in range(2)]
|
||||
_DELTA_DK = [torch.empty_like(dk) for _ in range(2)]
|
||||
_DELTA_DV = [torch.empty_like(dv) for _ in range(2)]
|
||||
_PEER_L = [torch.empty_like(L) for _ in range(2)]
|
||||
|
||||
_DK_DELTA_FROM_PEER = torch.empty_like(dk)
|
||||
_DV_DELTA_FROM_PEER = torch.empty_like(dv)
|
||||
|
||||
# may already be initailized in the forward call.
|
||||
# current forward and backward needs a transpose in q's format
|
||||
_PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)]
|
||||
_PEER_K_BWD = [torch.empty_like(k) for _ in range(2)]
|
||||
_PEER_V_BWD = [torch.empty_like(v) for _ in range(2)]
|
||||
_PEER_O_BWD = [torch.empty_like(o) for _ in range(2)]
|
||||
|
||||
_PEER_DO = [torch.empty_like(do) for _ in range(2)]
|
||||
|
||||
return (
|
||||
_DELTA_DQ,
|
||||
_DELTA_DK,
|
||||
_DELTA_DV,
|
||||
_DK_DELTA_FROM_PEER,
|
||||
_DV_DELTA_FROM_PEER,
|
||||
_PEER_Q_BWD,
|
||||
_PEER_L,
|
||||
_PEER_K_BWD,
|
||||
_PEER_V_BWD,
|
||||
_PEER_O_BWD,
|
||||
_PEER_DO,
|
||||
)
|
||||
|
||||
|
||||
def reset_global_memory_buffer():
|
||||
global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO
|
||||
_PEER_Q = None
|
||||
_PEER_K = None
|
||||
_PEER_V = None
|
||||
_PEER_M = None
|
||||
_PEER_L = None
|
||||
_PEER_O = None
|
||||
|
||||
_DELTA_DQ = None
|
||||
_PEER_L = None
|
||||
_DELTA_DK = None
|
||||
_DELTA_DV = None
|
||||
_DK_DELTA_FROM_PEER = None
|
||||
_DV_DELTA_FROM_PEER = None
|
||||
_PEER_DO = None
|
||||
|
||||
|
||||
# Pytorch defers the creation of nccl communicators to the first P2P call,
|
||||
# We manually create them so the first isend does not hang without an irecv.
|
||||
# reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138
|
||||
# Only support even number of GPUs.
|
||||
def create_nccl_communicators():
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_group = get_sequence_parallel_group()
|
||||
|
||||
empty_tensor = torch.empty(
|
||||
1,
|
||||
).cuda()
|
||||
empty_tensor_2 = torch.empty(
|
||||
1,
|
||||
).cuda()
|
||||
if torch.distributed.get_rank() % 2 == 0:
|
||||
# sender
|
||||
op1 = P2POp(
|
||||
op=isend,
|
||||
tensor=torch.empty(
|
||||
1,
|
||||
).cuda(),
|
||||
peer=seq_rank + 1,
|
||||
group=seq_group,
|
||||
)
|
||||
op2 = P2POp(
|
||||
op=irecv,
|
||||
tensor=torch.empty(
|
||||
1,
|
||||
).cuda(),
|
||||
peer=seq_rank + 1,
|
||||
group=seq_group,
|
||||
)
|
||||
# req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group)
|
||||
dist.batch_isend_irecv([op1, op2])
|
||||
else:
|
||||
# receiver
|
||||
op1 = P2POp(
|
||||
op=irecv,
|
||||
tensor=torch.empty(
|
||||
1,
|
||||
).cuda(),
|
||||
peer=seq_rank - 1,
|
||||
group=seq_group,
|
||||
)
|
||||
op2 = P2POp(
|
||||
op=isend,
|
||||
tensor=torch.empty(
|
||||
1,
|
||||
).cuda(),
|
||||
peer=seq_rank - 1,
|
||||
group=seq_group,
|
||||
)
|
||||
# req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group)
|
||||
handles = dist.batch_isend_irecv([op1, op2])
|
||||
# req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group)
|
||||
dist.all_reduce(empty_tensor, group=seq_group)
|
||||
|
||||
|
||||
def get_sequence_parallel_group():
|
||||
"""Get the sequence parallel group the caller rank belongs to."""
|
||||
# global _SEQUENCE_PARALLEL_GROUP
|
||||
assert (
|
||||
_SEQUENCE_PARALLEL_GROUP is not None
|
||||
), "sequence parallel group is not initialized"
|
||||
return _SEQUENCE_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_sequence_parallel_rank():
|
||||
"""Return my rank for the sequence parallel group."""
|
||||
global _SEQUENCE_PARALLEL_RANK
|
||||
if _SEQUENCE_PARALLEL_RANK is not None:
|
||||
return _SEQUENCE_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_sequence_parallel_group())
|
||||
|
||||
|
||||
def get_sequence_parallel_size():
|
||||
"""Return my rank for the sequence parallel group."""
|
||||
global _SEQUENCE_PARALLEL_SIZE
|
||||
if _SEQUENCE_PARALLEL_SIZE is not None:
|
||||
return _SEQUENCE_PARALLEL_SIZE
|
||||
return torch.distributed.get_world_size(group=get_sequence_parallel_group())
|
||||
|
||||
|
||||
def destroy_sequence_parallel():
|
||||
"""Set the groups to none."""
|
||||
global _SEQUENCE_PARALLEL_GROUP
|
||||
_SEQUENCE_PARALLEL_GROUP = None
|
||||
|
||||
|
||||
# whether this is the last time the kernel being called
|
||||
def is_last_time(time_step):
|
||||
# e.g. on a 8-GPU setup:
|
||||
# R=0: 0
|
||||
# R=1: 1
|
||||
# R=2: 2
|
||||
# R=3: 3
|
||||
# R=4: 4, 5, 6, 7
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
if seq_rank <= seq_world_size // 2: # no one helps these ranks
|
||||
rank_finish_time = seq_rank
|
||||
else:
|
||||
rank_finish_time = seq_world_size // 2
|
||||
return rank_finish_time == time_step
|
||||
|
||||
|
||||
# Whether the current time step is computing for local q
|
||||
def is_compute_for_local_query(time_step):
|
||||
# R=3,4,5,6,7: Yes
|
||||
# R=0: 0
|
||||
# R=1: 0, 1
|
||||
# R=2: 0, 1, 2
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
if seq_rank >= min(seq_world_size // 2, time_step):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Whether the current time step is idle
|
||||
def is_idle(time_step):
|
||||
# 0, 1, 2, 3: 4
|
||||
# 4, 5, 6, 7: No
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
|
||||
if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Whether the current time step needs to synchronize with a remote computed result
|
||||
def is_sync_from_remote(time_step):
|
||||
# R=0, 1, 2, 3, 4: No
|
||||
# R=5: 4
|
||||
# R=6: 3, 4
|
||||
# R=7: 2, 3, 4
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
if seq_rank > max(seq_world_size // 2, seq_world_size - time_step):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def maybe_send_recv_fwd_qkvo(
|
||||
q: torch.Tensor,
|
||||
peer_q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
peer_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
peer_v: torch.Tensor,
|
||||
o_stats: list, # peer_o_stats: list,
|
||||
time_step: int,
|
||||
comm_mode,
|
||||
debug=False,
|
||||
) -> torch.Tensor:
|
||||
seq_group = get_sequence_parallel_group()
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
|
||||
# Handles for operations that actually need to be wait before going to the next iteration.
|
||||
# For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler;
|
||||
all_handles = []
|
||||
# KV logic: different than older version, every rank to send/recv its own kv,
|
||||
# to balance communication. In a balanced communication, every step each rank
|
||||
# should send/recv 4 tensors in total (kv, or qo). For instance, rank 0 when
|
||||
# time step > 0, should send its own kv and send/recv qo. In the older version,
|
||||
# rank 0 does not send its kv, and rely on a later rank to pass it, where the
|
||||
# later rank has to (1) receive kv, send rank 0's kv and send/recv qo.
|
||||
# Q (load balancing) logic: semantically, this will be "%" world size, so
|
||||
# the same send/recv rank as KV. Note: Only support even number of machines.
|
||||
# O (load balancing) logic: rank 0 sends result to rank 7 at time 1.
|
||||
# It get delayed for one time step, and thus has different maybe_send/recv_rank.
|
||||
# Use (time_step + 1) to easily convert to synchornize version.
|
||||
maybe_send_rank = seq_rank + (time_step + 1)
|
||||
maybe_recv_rank = seq_rank - (time_step + 1)
|
||||
|
||||
if debug:
|
||||
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
|
||||
_debug_send = _fwd_send_volume
|
||||
_debug_recv = _fwd_recv_volume
|
||||
|
||||
if maybe_send_rank >= seq_world_size:
|
||||
# send q, no one needs to do remote computation in the last time step
|
||||
if time_step < (seq_world_size // 2 - 1):
|
||||
# print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)")
|
||||
# q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group))
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=q,
|
||||
peer=maybe_send_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_fwd_send_volume += torch.numel(q) * q.element_size()
|
||||
else:
|
||||
# send kv
|
||||
# print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)")
|
||||
# kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group))
|
||||
# kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group))
|
||||
all_handles.append(
|
||||
P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)
|
||||
)
|
||||
if debug:
|
||||
_fwd_send_volume += torch.numel(k) * k.element_size()
|
||||
_fwd_send_volume += torch.numel(v) * v.element_size()
|
||||
|
||||
if maybe_recv_rank < 0:
|
||||
# recv q, no one needs to do remote computation in the last time step
|
||||
if time_step < (seq_world_size // 2 - 1):
|
||||
# print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)")
|
||||
# q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group))
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=peer_q,
|
||||
peer=maybe_recv_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size()
|
||||
else:
|
||||
# recv kv
|
||||
# print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)")
|
||||
# kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group))
|
||||
# kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group))
|
||||
all_handles.append(
|
||||
P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)
|
||||
)
|
||||
if debug:
|
||||
_fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size()
|
||||
_fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size()
|
||||
|
||||
maybe_send_rank_o = seq_rank - (time_step - 1)
|
||||
maybe_recv_rank_o = seq_rank + (time_step - 1)
|
||||
if maybe_send_rank_o < 0 and time_step > 1:
|
||||
for t in o_stats:
|
||||
# print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)")
|
||||
# o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group))
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=t,
|
||||
peer=maybe_send_rank_o % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_fwd_send_volume += torch.numel(t) * t.element_size()
|
||||
if maybe_recv_rank_o >= seq_world_size and time_step > 1:
|
||||
for t in o_stats:
|
||||
# print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)")
|
||||
# o_recv_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group))
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=t,
|
||||
peer=maybe_recv_rank_o % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_fwd_recv_volume += torch.numel(t) * t.element_size()
|
||||
|
||||
# reqs = []
|
||||
|
||||
if debug:
|
||||
if seq_rank in [0, 8]:
|
||||
print(
|
||||
f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB"
|
||||
)
|
||||
# return reqs
|
||||
all_reqs = launch_async_handles(all_handles, comm_mode)
|
||||
return [all_reqs]
|
||||
|
||||
|
||||
# delta: may be you are using it for your local compute or as a distributed buffer to send to others
|
||||
# .. Sorry for the bad naming..
|
||||
def maybe_send_recv_bwd_qkvo(
|
||||
dq_delta: torch.Tensor,
|
||||
dk_delta: torch.Tensor,
|
||||
dv_delta: torch.Tensor,
|
||||
dk_delta_from_peer: torch.Tensor,
|
||||
dv_delta_from_peer: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
peer_q: torch.Tensor,
|
||||
L: torch.Tensor,
|
||||
peer_L: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
peer_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
peer_v: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
peer_o: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
peer_do: torch.Tensor,
|
||||
time_step: int,
|
||||
comm_mode,
|
||||
debug=False,
|
||||
):
|
||||
seq_group = get_sequence_parallel_group()
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
|
||||
all_handles = []
|
||||
maybe_send_rank = seq_rank + (time_step + 1)
|
||||
maybe_recv_rank = seq_rank - (time_step + 1)
|
||||
|
||||
if debug:
|
||||
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
|
||||
|
||||
if maybe_send_rank >= seq_world_size:
|
||||
# send q, no one needs to do remote computation in the last time step
|
||||
if time_step < (seq_world_size // 2 - 1):
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=q,
|
||||
peer=maybe_send_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=L,
|
||||
peer=maybe_send_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=o,
|
||||
peer=maybe_send_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=do,
|
||||
peer=maybe_send_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_bwd_send_volume += torch.numel(q) * q.element_size()
|
||||
_bwd_send_volume += torch.numel(L) * L.element_size()
|
||||
_bwd_send_volume += torch.numel(o) * o.element_size()
|
||||
_bwd_send_volume += torch.numel(do) * do.element_size()
|
||||
else:
|
||||
# send kv
|
||||
all_handles.append(
|
||||
P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)
|
||||
)
|
||||
if debug:
|
||||
_bwd_send_volume += torch.numel(k) * k.element_size()
|
||||
_bwd_send_volume += torch.numel(v) * v.element_size()
|
||||
|
||||
if maybe_recv_rank < 0:
|
||||
# recv q, no one needs to do remote computation in the last time step
|
||||
if time_step < (seq_world_size // 2 - 1):
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=peer_q,
|
||||
peer=maybe_recv_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=peer_L,
|
||||
peer=maybe_recv_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=peer_o,
|
||||
peer=maybe_recv_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=peer_do,
|
||||
peer=maybe_recv_rank % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size()
|
||||
_bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size()
|
||||
_bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size()
|
||||
_bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size()
|
||||
else:
|
||||
# recv kv
|
||||
all_handles.append(
|
||||
P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)
|
||||
)
|
||||
if debug:
|
||||
_bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size()
|
||||
_bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size()
|
||||
|
||||
# Whether I should update dq, dk and dv after waiting these requests
|
||||
is_update_dq = False
|
||||
is_update_dkv = False
|
||||
|
||||
maybe_send_rank_dqkv = seq_rank - (time_step - 1)
|
||||
maybe_recv_rank_dqkv = seq_rank + (time_step - 1)
|
||||
|
||||
if time_step > 1:
|
||||
if maybe_send_rank_dqkv < 0:
|
||||
# print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}")
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=dq_delta,
|
||||
peer=maybe_send_rank_dqkv % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size()
|
||||
else:
|
||||
# print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}")
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=dk_delta,
|
||||
peer=maybe_send_rank_dqkv,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=isend,
|
||||
tensor=dv_delta,
|
||||
peer=maybe_send_rank_dqkv,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
if debug:
|
||||
_bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size()
|
||||
_bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size()
|
||||
|
||||
if maybe_recv_rank_dqkv >= seq_world_size:
|
||||
# print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}")
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=dq_delta,
|
||||
peer=maybe_recv_rank_dqkv % seq_world_size,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
is_update_dq = True
|
||||
if debug:
|
||||
_bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size()
|
||||
else:
|
||||
# print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}")
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=dk_delta_from_peer,
|
||||
peer=maybe_recv_rank_dqkv,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(
|
||||
op=irecv,
|
||||
tensor=dv_delta_from_peer,
|
||||
peer=maybe_recv_rank_dqkv,
|
||||
group=seq_group,
|
||||
)
|
||||
)
|
||||
is_update_dkv = True
|
||||
if debug:
|
||||
_bwd_recv_volume += (
|
||||
torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size()
|
||||
)
|
||||
_bwd_recv_volume += (
|
||||
torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size()
|
||||
)
|
||||
|
||||
# return [], is_update_dq, is_update_dkv
|
||||
all_reqs = launch_async_handles(all_handles, comm_mode)
|
||||
return [all_reqs], is_update_dq, is_update_dkv
|
||||
|
||||
|
||||
def maybe_send_recv_bwd_last_dkv(
|
||||
dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False
|
||||
):
|
||||
is_update_last_dkv = False
|
||||
|
||||
seq_group = get_sequence_parallel_group()
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
seq_world_size = get_sequence_parallel_size()
|
||||
|
||||
if seq_world_size == 1:
|
||||
return [], is_update_last_dkv
|
||||
|
||||
all_handles = []
|
||||
|
||||
if debug:
|
||||
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
|
||||
|
||||
if time_step == seq_world_size // 2:
|
||||
maybe_send_rank = seq_rank - time_step
|
||||
maybe_recv_rank = seq_rank + time_step
|
||||
|
||||
assert (maybe_send_rank >= 0) ^ (
|
||||
maybe_recv_rank < seq_world_size
|
||||
), "R={seq_rank} should be either sending or receiving dkv in the last time step."
|
||||
|
||||
if maybe_send_rank >= 0:
|
||||
# print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}")
|
||||
all_handles.append(
|
||||
P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group)
|
||||
)
|
||||
if debug:
|
||||
_bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size()
|
||||
_bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size()
|
||||
if maybe_recv_rank < seq_world_size:
|
||||
# print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}")
|
||||
all_handles.append(
|
||||
P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group)
|
||||
)
|
||||
all_handles.append(
|
||||
P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group)
|
||||
)
|
||||
if debug:
|
||||
_bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size()
|
||||
_bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size()
|
||||
is_update_last_dkv = True
|
||||
|
||||
# return [], is_update_last_dkv
|
||||
all_reqs = launch_async_handles(all_handles, comm_mode)
|
||||
|
||||
return [all_reqs], is_update_last_dkv
|
||||
|
||||
|
||||
def print_and_reset_comm_stats():
|
||||
seq_rank = get_sequence_parallel_rank()
|
||||
|
||||
global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume
|
||||
_fwd_send_volume *= 1e-9
|
||||
_fwd_recv_volume *= 1e-9
|
||||
_bwd_send_volume *= 1e-9
|
||||
_bwd_recv_volume *= 1e-9
|
||||
|
||||
print(
|
||||
f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB."
|
||||
)
|
||||
_fwd_send_volume = 0
|
||||
_fwd_recv_volume = 0
|
||||
_bwd_send_volume = 0
|
||||
_bwd_recv_volume = 0
|
||||
|
||||
|
||||
def launch_async_handles(handles, comm_mode):
|
||||
global _args
|
||||
if comm_mode == "nocomm":
|
||||
# print("skipping communication for ablation")
|
||||
return []
|
||||
if len(handles) > 0:
|
||||
return dist.batch_isend_irecv(handles)
|
||||
return []
|
||||
|
||||
|
||||
def wait_async_handles(reqs):
|
||||
if len(reqs) > 0:
|
||||
for req in reqs:
|
||||
for r in req:
|
||||
r.wait()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,754 +0,0 @@
|
||||
"""
|
||||
Materialization-aware gradient checkpointing monkey patch.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch.utils.checkpoint import (
|
||||
_get_autocast_kwargs,
|
||||
check_backward_validity,
|
||||
detach_variable,
|
||||
get_device_states,
|
||||
set_device_states,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
BaseModelOutputWithPast,
|
||||
LlamaDecoderLayer,
|
||||
LlamaModel,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
from .async_communication import initialize_distributed
|
||||
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
|
||||
|
||||
# define a global buffer to save flash attention outputs
|
||||
# it's called global because it saves the outputs for all layers
|
||||
global_flash_attn_out_buffer = None
|
||||
|
||||
# define a local buffer to save recomputed qkv
|
||||
# it's called local because it's a temporary buffer which will be updated across layers
|
||||
local_res_grad_buffer = None
|
||||
|
||||
# hooks for the gradients of residual
|
||||
global_hooks = []
|
||||
|
||||
|
||||
def init_flash_attn_buffers(num_layers):
|
||||
# update the global buffer according to number of layers
|
||||
global global_flash_attn_out_buffer
|
||||
global_flash_attn_out_buffer = [None] * num_layers
|
||||
|
||||
|
||||
def clean_hook():
|
||||
# Remove all hooks in the global buffer
|
||||
for hook in global_hooks:
|
||||
hook.remove()
|
||||
# Clear the global buffer
|
||||
global_hooks.clear()
|
||||
|
||||
|
||||
def clear_all_buffers_at_the_end_of_training():
|
||||
# call it at the end of training
|
||||
global lobal_flash_attn_out_buffer
|
||||
global_flash_attn_out_buffer = None
|
||||
global local_res_grad_buffer
|
||||
local_res_grad_buffer = None
|
||||
clean_hook()
|
||||
|
||||
|
||||
def save_flash_attn_out_to_global_buffer(idx, out):
|
||||
global global_flash_attn_out_buffer
|
||||
global_flash_attn_out_buffer[idx] = out
|
||||
|
||||
|
||||
def get_flash_attn_out_from_global_buffer(idx):
|
||||
global global_flash_attn_out_buffer
|
||||
return global_flash_attn_out_buffer[idx]
|
||||
|
||||
|
||||
def free_flash_attn_out_buffer(idx):
|
||||
global global_flash_attn_out_buffer
|
||||
global_flash_attn_out_buffer[idx] = None
|
||||
|
||||
|
||||
def write_gradient_to_flash_attn_out(idx, grad):
|
||||
global global_flash_attn_out_buffer
|
||||
global_flash_attn_out_buffer[idx].grad = grad
|
||||
|
||||
|
||||
def save_res_grad_hook(grad):
|
||||
global local_res_grad_buffer
|
||||
local_res_grad_buffer = grad
|
||||
|
||||
|
||||
def load_and_add_res_grad_hook(grad):
|
||||
grad += get_res_grad_from_local_buffer()
|
||||
|
||||
|
||||
def get_res_grad_from_local_buffer():
|
||||
global local_res_grad_buffer
|
||||
assert local_res_grad_buffer is not None
|
||||
return local_res_grad_buffer
|
||||
|
||||
|
||||
class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function):
|
||||
"""Avoid doing twice flash attention forward during checkpointed backward.
|
||||
args:
|
||||
hidden_states, # i.e., flash attention output which is saved in global buffer.
|
||||
attention_mask,
|
||||
position_ids,
|
||||
residual, # the gradient of residual is saved in local buffer to pass across ckpt layers.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
|
||||
check_backward_validity(args)
|
||||
ctx.run_function = run_function
|
||||
ctx.layer_idx = layer_idx
|
||||
ctx.preserve_rng_state = preserve_rng_state
|
||||
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||
ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
|
||||
if preserve_rng_state:
|
||||
ctx.fwd_cpu_state = torch.get_rng_state()
|
||||
# Don't eagerly initialize the cuda context by accident.
|
||||
# (If the user intends that the context is initialized later, within their
|
||||
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
|
||||
# we have no way to anticipate this will happen before we run the function.)
|
||||
ctx.had_cuda_in_fwd = False
|
||||
if torch.cuda._initialized:
|
||||
ctx.had_cuda_in_fwd = True
|
||||
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
|
||||
|
||||
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||
# to be filled out during the backward.
|
||||
ctx.inputs = []
|
||||
ctx.tensor_indices = []
|
||||
tensor_inputs = []
|
||||
for i, arg in enumerate(args):
|
||||
if i == 0 and ctx.layer_idx != 0:
|
||||
# flash attention output is saved to the global buffer during forward
|
||||
ctx.inputs.append(None)
|
||||
else:
|
||||
if torch.is_tensor(arg):
|
||||
tensor_inputs.append(arg)
|
||||
ctx.tensor_indices.append(i)
|
||||
ctx.inputs.append(None)
|
||||
else:
|
||||
ctx.inputs.append(arg)
|
||||
|
||||
with torch.no_grad():
|
||||
q, k, v, residual = run_function(*args)
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
|
||||
# lightseq version
|
||||
_, _, _, out, softmax_lse = _lightseq_forward(
|
||||
q, k, v, True, softmax_scale, comm_mode="lightseq"
|
||||
)
|
||||
rng_state = None
|
||||
|
||||
# save flash attention output to global buffer
|
||||
save_flash_attn_out_to_global_buffer(ctx.layer_idx, out)
|
||||
tensor_inputs += [softmax_lse]
|
||||
ctx.softmax_scale = softmax_scale
|
||||
|
||||
ctx.save_for_backward(*tensor_inputs)
|
||||
|
||||
return out, residual
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError(
|
||||
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
|
||||
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
|
||||
" argument."
|
||||
)
|
||||
# Copy the list to avoid modifying original list.
|
||||
inputs = list(ctx.inputs)
|
||||
tensor_indices = ctx.tensor_indices
|
||||
tensors = ctx.saved_tensors
|
||||
tensors, softmax_lse = tensors[:-1], tensors[-1]
|
||||
|
||||
# Fill in inputs with appropriate saved tensors.
|
||||
# Fill the flash attention output first
|
||||
if ctx.layer_idx > 0:
|
||||
# inputs[0] should be flash attention output
|
||||
inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx - 1)
|
||||
for i, idx in enumerate(tensor_indices):
|
||||
inputs[idx] = tensors[i]
|
||||
|
||||
# Stash the surrounding rng state, and mimic the state that was
|
||||
# present at this time during forward. Restore the surrounding state
|
||||
# when we're done.
|
||||
rng_devices = []
|
||||
if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
|
||||
rng_devices = ctx.fwd_gpu_devices
|
||||
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
|
||||
if ctx.preserve_rng_state:
|
||||
torch.set_rng_state(ctx.fwd_cpu_state)
|
||||
if ctx.had_cuda_in_fwd:
|
||||
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
|
||||
detached_inputs = detach_variable(tuple(inputs))
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(
|
||||
**ctx.gpu_autocast_kwargs
|
||||
), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
|
||||
# Stop recomputation before flash attention
|
||||
# It is unecessary to run recomputation for flash attn
|
||||
q, k, v, residual = ctx.run_function(*detached_inputs)
|
||||
|
||||
# run backward() with only tensor that requires grad
|
||||
# run flash attention backward first:
|
||||
# get 'dout' from auto_grad inputs
|
||||
# get 'out' from global buffer
|
||||
# get 'qkv' from the recomputed tensors
|
||||
# dq = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
||||
# dk = torch.empty(k.shape, dtype=q.dtype, device=q.device)
|
||||
# dv = torch.empty(v.shape, dtype=q.dtype, device=q.device)
|
||||
out = get_flash_attn_out_from_global_buffer(ctx.layer_idx)
|
||||
# todo get dout
|
||||
dout = args[0]
|
||||
|
||||
# lightseq version
|
||||
dq, dk, dv = _lightseq_backward(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
ctx.softmax_scale,
|
||||
comm_mode="lightseq",
|
||||
backward_engine="flash",
|
||||
)
|
||||
# dqkv = torch.stack([dq, dk, dv])
|
||||
|
||||
# run backward for the part before flash attention
|
||||
# qkv.backward(dqkv)
|
||||
torch.autograd.backward([q, k, v], [dq, dk, dv])
|
||||
|
||||
grads = tuple(
|
||||
inp.grad if isinstance(inp, torch.Tensor) else None
|
||||
for inp in detached_inputs
|
||||
)
|
||||
|
||||
# write flash attention output gradients to buffer
|
||||
if ctx.layer_idx > 0:
|
||||
write_gradient_to_flash_attn_out(ctx.layer_idx - 1, detached_inputs[0].grad)
|
||||
|
||||
return (None, None, None) + grads
|
||||
|
||||
|
||||
def checkpoint_end_with_flash_attention(
|
||||
function, layer_idx, *args, use_reentrant: bool = True, **kwargs
|
||||
):
|
||||
# Hack to mix *args with **kwargs in a python 2.7-compliant way
|
||||
preserve = kwargs.pop("preserve_rng_state", True)
|
||||
if kwargs and use_reentrant:
|
||||
raise ValueError(
|
||||
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
|
||||
)
|
||||
|
||||
return CheckpointFunctionEndWithFlashAttention.apply(
|
||||
function, layer_idx, preserve, *args
|
||||
)
|
||||
|
||||
|
||||
class CheckpointFunctionLastModule(torch.autograd.Function):
|
||||
"""
|
||||
for the last ffn layer after flash attention, modifications include:
|
||||
write the gradients wrt flash attention output and residual to the global buffer.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, preserve_rng_state, *args):
|
||||
check_backward_validity(args)
|
||||
ctx.run_function = run_function
|
||||
ctx.preserve_rng_state = preserve_rng_state
|
||||
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||
ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
|
||||
if preserve_rng_state:
|
||||
ctx.fwd_cpu_state = torch.get_rng_state()
|
||||
# Don't eagerly initialize the cuda context by accident.
|
||||
# (If the user intends that the context is initialized later, within their
|
||||
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
|
||||
# we have no way to anticipate this will happen before we run the function.)
|
||||
ctx.had_cuda_in_fwd = False
|
||||
if torch.cuda._initialized:
|
||||
ctx.had_cuda_in_fwd = True
|
||||
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
|
||||
|
||||
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||
# to be filled out during the backward.
|
||||
ctx.inputs = []
|
||||
ctx.tensor_indices = []
|
||||
tensor_inputs = []
|
||||
|
||||
assert torch.is_tensor(
|
||||
args[0]
|
||||
), "assuming the first tensor is the flash attention output"
|
||||
for i, arg in enumerate(args):
|
||||
if torch.is_tensor(arg) and i == 0:
|
||||
# flash attn output has been saved to global buffer
|
||||
ctx.inputs.append(None)
|
||||
elif torch.is_tensor(arg):
|
||||
tensor_inputs.append(arg)
|
||||
ctx.tensor_indices.append(i)
|
||||
ctx.inputs.append(None)
|
||||
else:
|
||||
ctx.inputs.append(arg)
|
||||
|
||||
ctx.save_for_backward(*tensor_inputs)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*args)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError(
|
||||
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
|
||||
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
|
||||
" argument."
|
||||
)
|
||||
# Copy the list to avoid modifying original list.
|
||||
inputs = list(ctx.inputs)
|
||||
tensor_indices = ctx.tensor_indices
|
||||
tensors = ctx.saved_tensors
|
||||
|
||||
# Fill in inputs with appropriate saved tensors.
|
||||
# Fill the flash attention output first
|
||||
# inputs[0] should be flash attention output
|
||||
inputs[0] = get_flash_attn_out_from_global_buffer(-1)
|
||||
for i, idx in enumerate(tensor_indices):
|
||||
inputs[idx] = tensors[i]
|
||||
|
||||
# Stash the surrounding rng state, and mimic the state that was
|
||||
# present at this time during forward. Restore the surrounding state
|
||||
# when we're done.
|
||||
rng_devices = []
|
||||
if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
|
||||
rng_devices = ctx.fwd_gpu_devices
|
||||
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
|
||||
if ctx.preserve_rng_state:
|
||||
torch.set_rng_state(ctx.fwd_cpu_state)
|
||||
if ctx.had_cuda_in_fwd:
|
||||
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
|
||||
detached_inputs = detach_variable(tuple(inputs))
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(
|
||||
**ctx.gpu_autocast_kwargs
|
||||
), torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
|
||||
# run backward() with only tensor that requires grad
|
||||
outputs_with_grad = []
|
||||
args_with_grad = []
|
||||
for i in range(len(outputs)):
|
||||
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
||||
outputs_with_grad.append(outputs[i])
|
||||
args_with_grad.append(args[i])
|
||||
if len(outputs_with_grad) == 0:
|
||||
raise RuntimeError(
|
||||
"none of output has requires_grad=True,"
|
||||
" this checkpoint() is not necessary"
|
||||
)
|
||||
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||
grads = tuple(
|
||||
inp.grad if isinstance(inp, torch.Tensor) else None
|
||||
for inp in detached_inputs
|
||||
)
|
||||
|
||||
# write flash attention output gradients to buffer
|
||||
write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad)
|
||||
|
||||
return (None, None) + grads
|
||||
|
||||
|
||||
def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs):
|
||||
preserve = kwargs.pop("preserve_rng_state", True)
|
||||
if kwargs and use_reentrant:
|
||||
raise ValueError(
|
||||
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
|
||||
)
|
||||
|
||||
return CheckpointFunctionLastModule.apply(function, preserve, *args)
|
||||
|
||||
|
||||
def llama_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
compute_attn_only: Optional[bool] = False,
|
||||
compute_ffn_only: Optional[bool] = False,
|
||||
residual: Optional[bool] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
assert compute_ffn_only or compute_attn_only
|
||||
|
||||
if compute_attn_only:
|
||||
residual = hidden_states
|
||||
|
||||
if residual.requires_grad:
|
||||
# register a hook to add the gradient of residual
|
||||
# from next checkpoint layer when doing recomputation
|
||||
hook = residual.register_hook(load_and_add_res_grad_hook)
|
||||
global_hooks.append(hook)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Flash Attention
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
try:
|
||||
query_states = (
|
||||
self.self_attn.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.self_attn.k_proj(hidden_states)
|
||||
.view(
|
||||
bsz,
|
||||
q_len,
|
||||
self.self_attn.num_key_value_heads,
|
||||
self.self_attn.head_dim,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.self_attn.v_proj(hidden_states)
|
||||
.view(
|
||||
bsz,
|
||||
q_len,
|
||||
self.self_attn.num_key_value_heads,
|
||||
self.self_attn.head_dim,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
except:
|
||||
# old transformers versions don't support num_key_value_heads
|
||||
query_states = (
|
||||
self.self_attn.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.self_attn.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.self_attn.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
assert past_key_value is None, "past_key_value is not supported"
|
||||
|
||||
cos, sin = self.self_attn.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
assert not output_attentions, "output_attentions is not supported"
|
||||
assert not use_cache, "use_cache is not supported"
|
||||
return (
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
residual,
|
||||
)
|
||||
|
||||
elif compute_ffn_only:
|
||||
hidden_states = self.self_attn.o_proj(
|
||||
rearrange(hidden_states, "b h s d -> b s (h d)")
|
||||
)
|
||||
# Need to add residual here to make sure checkpoint is right after attention
|
||||
if residual.requires_grad:
|
||||
# save the gradient of residual to the local buffer
|
||||
# collect the hooks which should be removed after backward to avoid memory leak
|
||||
hook = residual.register_hook(save_res_grad_hook)
|
||||
global_hooks.append(hook)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
else:
|
||||
raise AttributeError
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
assert cache_position is None, "cache_position is not supported"
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
attention_mask = None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
try:
|
||||
logger.warning_once("***** Using fast gradient checkpointing... *****")
|
||||
except:
|
||||
pass
|
||||
# initialize the global buffer
|
||||
init_flash_attn_buffers(len(self.layers))
|
||||
|
||||
if use_cache:
|
||||
try:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
except:
|
||||
pass
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# apply flash-attention friendly gradient checkpointing
|
||||
if self.gradient_checkpointing and self.training:
|
||||
for idx in range(len(self.layers) + 1):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[idx] if past_key_values is not None else None
|
||||
)
|
||||
|
||||
def forward_first_attn_module(module):
|
||||
def custom_forward(*inputs):
|
||||
hidden_states, attention_mask, position_ids, _ = inputs
|
||||
# None for past_key_value
|
||||
return module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
compute_attn_only=True,
|
||||
)
|
||||
|
||||
return custom_forward
|
||||
|
||||
def forward_ffn_attn_layer(module1, module2):
|
||||
def custom_forward(*inputs):
|
||||
hidden_states, attention_mask, position_ids, residual = inputs
|
||||
# None for past_key_value
|
||||
layer_outputs = module1(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
compute_ffn_only=True,
|
||||
residual=residual,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
return module2(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
compute_attn_only=True,
|
||||
)
|
||||
|
||||
return custom_forward
|
||||
|
||||
def forward_last_ffn_module(module):
|
||||
def custom_forward(*inputs):
|
||||
hidden_states, attention_mask, position_ids, residual = inputs
|
||||
# None for past_key_value
|
||||
return module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
compute_ffn_only=True,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if idx == 0:
|
||||
layer_outputs = checkpoint_end_with_flash_attention(
|
||||
forward_first_attn_module(self.layers[0]),
|
||||
idx,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
hidden_states, residual = layer_outputs[0], layer_outputs[-1]
|
||||
elif idx == len(self.layers):
|
||||
layer_outputs = checkpoint_last_module(
|
||||
forward_last_ffn_module(self.layers[-1]),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
residual,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
else:
|
||||
layer_outputs = checkpoint_end_with_flash_attention(
|
||||
forward_ffn_attn_layer(self.layers[idx - 1], self.layers[idx]),
|
||||
idx,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
residual,
|
||||
)
|
||||
hidden_states, residual = layer_outputs[0], layer_outputs[-1]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
else:
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[idx] if past_key_values is not None else None
|
||||
)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def apply_dist_flash_attn_monkey_patch_llama():
|
||||
initialize_distributed()
|
||||
|
||||
LlamaModel.forward = forward
|
||||
LlamaDecoderLayer.forward = llama_layer_forward
|
||||
@@ -1,34 +0,0 @@
|
||||
def extract_local(value, rank, world_size, device, dim=1):
|
||||
value_local = value.chunk(world_size, dim=dim)[rank]
|
||||
return value_local.to(device)
|
||||
|
||||
|
||||
def prepare_dist_flash_attn_inputs(
|
||||
input_ids, position_ids, target_ids, rank, world_size, device
|
||||
):
|
||||
local_input_ids = extract_local(
|
||||
input_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
local_position_ids = extract_local(
|
||||
position_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
if target_ids is not None:
|
||||
local_target_ids = extract_local(
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
else:
|
||||
local_target_ids = None
|
||||
return {
|
||||
"local_input_ids": local_input_ids,
|
||||
"local_position_ids": local_position_ids,
|
||||
"local_target_ids": local_target_ids,
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
|
||||
try:
|
||||
from yunchang.ulysses import UlyssesAttention
|
||||
|
||||
ulysses_attn = UlyssesAttention()
|
||||
except:
|
||||
print(
|
||||
"If you want to use the UlyssesAttention class, please install the yunchang package."
|
||||
)
|
||||
ulysses_attn = None
|
||||
|
||||
|
||||
def new_flash_attn_forward(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length,
|
||||
dropout=0.0,
|
||||
softmax_scale=None,
|
||||
use_sliding_windows=False,
|
||||
):
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
assert attention_mask is None
|
||||
assert causal is True
|
||||
assert use_sliding_windows is False
|
||||
attn_output = ulysses_attn(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def new_decoder_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
assert isinstance(
|
||||
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
|
||||
) or isinstance(
|
||||
self.self_attn,
|
||||
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
|
||||
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
|
||||
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def apply_ulysses_attn_monkey_patch_llama():
|
||||
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
|
||||
new_flash_attn_forward
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
|
||||
new_decoder_forward
|
||||
)
|
||||
@@ -1,44 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def extract_local(value, rank, world_size, device, dim=1):
|
||||
dimension_size = value.shape[dim]
|
||||
sub_seq_length = dimension_size // world_size
|
||||
|
||||
sub_seq_start = rank * sub_seq_length
|
||||
sub_seq_end = (rank + 1) * sub_seq_length
|
||||
local_value = value[:, sub_seq_start:sub_seq_end]
|
||||
|
||||
return local_value.to(device)
|
||||
|
||||
|
||||
def prepare_ulysses_attn_inputs(
|
||||
input_ids, position_ids, target_ids, rank, world_size, device
|
||||
):
|
||||
local_input_ids = extract_local(
|
||||
input_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
local_position_ids = extract_local(
|
||||
position_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
|
||||
if target_ids is not None:
|
||||
local_target_ids = extract_local(
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
else:
|
||||
local_target_ids = None
|
||||
return {
|
||||
"local_input_ids": local_input_ids,
|
||||
"local_position_ids": local_position_ids,
|
||||
"local_target_ids": local_target_ids,
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
|
||||
"""
|
||||
Saves VRAM by smartly offloading to RAM.
|
||||
Tiny hit to performance, since we mask the movement via non blocking calls.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
ctx.save_for_backward(saved_hidden_states)
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
|
||||
return output
|
||||
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dY):
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
hidden_states.requires_grad = True
|
||||
with torch.enable_grad():
|
||||
(output,) = ctx.forward_function(hidden_states, *ctx.args)
|
||||
torch.autograd.backward(output, dY)
|
||||
return (
|
||||
None,
|
||||
hidden_states.grad,
|
||||
) + (
|
||||
None,
|
||||
) * len(ctx.args)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
||||
assert gradient_checkpointing_kwargs == None
|
||||
if not self.supports_gradient_checkpointing:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support gradient checkpointing."
|
||||
)
|
||||
|
||||
gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
|
||||
# For old GC format (transformers < 4.35.0) for models that live on the Hub
|
||||
# we will fall back to the overwritten `_set_gradient_checkpointing` method
|
||||
_is_using_old_format = (
|
||||
"value" in inspect.signature(self._set_gradient_checkpointing).parameters
|
||||
)
|
||||
|
||||
if not _is_using_old_format:
|
||||
self._set_gradient_checkpointing(
|
||||
enable=True, gradient_checkpointing_func=gradient_checkpointing_func
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if getattr(self, "_hf_peft_config_loaded", False):
|
||||
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
|
||||
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
|
||||
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
|
||||
# the gradients to make sure the gradient flows.
|
||||
self.enable_input_require_grads()
|
||||
|
||||
|
||||
def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
|
||||
transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
|
||||
new_gradient_checkpointing_enable
|
||||
)
|
||||
@@ -1,114 +0,0 @@
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
|
||||
try:
|
||||
from yunchang import LongContextAttention, set_seq_parallel_pg
|
||||
|
||||
usp_attn = LongContextAttention(ring_impl_type="zigzag")
|
||||
except:
|
||||
print(
|
||||
"If you want to use the LongContextAttention class, please install the yunchang package."
|
||||
)
|
||||
usp_attn = None
|
||||
|
||||
|
||||
def new_flash_attn_forward(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length,
|
||||
dropout=0.0,
|
||||
softmax_scale=None,
|
||||
use_sliding_windows=False,
|
||||
):
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
assert attention_mask is None
|
||||
assert causal is True
|
||||
assert use_sliding_windows is False
|
||||
attn_output = usp_attn(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def new_decoder_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
assert isinstance(
|
||||
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
|
||||
) or isinstance(
|
||||
self.self_attn,
|
||||
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
|
||||
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
|
||||
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def apply_usp_attn_monkey_patch_llama():
|
||||
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
|
||||
new_flash_attn_forward
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
|
||||
new_decoder_forward
|
||||
)
|
||||
@@ -1,58 +0,0 @@
|
||||
import torch
|
||||
from yunchang import set_seq_parallel_pg
|
||||
from yunchang.comm import zigzag_extract_local
|
||||
|
||||
|
||||
def prepare_usp_attn_inputs(
|
||||
input_ids,
|
||||
position_ids,
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
ring_degree,
|
||||
ulysses_degree,
|
||||
):
|
||||
f"""
|
||||
prepare input for USP attention
|
||||
|
||||
USP: A Unified Sequence Parallelism Approach for Long Context Generative AI
|
||||
https://arxiv.org/abs/2405.07719
|
||||
"""
|
||||
|
||||
set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)
|
||||
|
||||
local_input_ids = zigzag_extract_local(
|
||||
input_ids,
|
||||
rank,
|
||||
world_size,
|
||||
ring_degree,
|
||||
ulysses_degree,
|
||||
).to(device)
|
||||
|
||||
# truncate position_ids to the same size as input_ids
|
||||
position_ids = position_ids[:, : local_input_ids.shape[1]]
|
||||
|
||||
local_position_ids = zigzag_extract_local(
|
||||
position_ids,
|
||||
rank,
|
||||
world_size,
|
||||
ring_degree,
|
||||
ulysses_degree,
|
||||
).to(device)
|
||||
|
||||
if target_ids is not None:
|
||||
local_target_ids = zigzag_extract_local(
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
ring_degree,
|
||||
ulysses_degree,
|
||||
).to(device)
|
||||
else:
|
||||
local_target_ids = None
|
||||
return {
|
||||
"local_input_ids": local_input_ids,
|
||||
"local_position_ids": local_position_ids,
|
||||
"local_target_ids": local_target_ids,
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention,
|
||||
MistralDecoderLayer,
|
||||
)
|
||||
|
||||
|
||||
def new_flash_attn_forward(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length,
|
||||
dropout=0.0,
|
||||
softmax_scale=None,
|
||||
use_sliding_windows=False,
|
||||
):
|
||||
assert (
|
||||
self.config._attn_implementation == "flash_attention_2"
|
||||
), "Only Flash Attention is supported."
|
||||
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
assert attention_mask is None
|
||||
assert causal is True
|
||||
assert use_sliding_windows is False
|
||||
attn_output = zigzag_ring_flash_attn_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def new_decoder_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
assert isinstance(self.self_attn, LlamaAttention) or isinstance(
|
||||
self.self_attn,
|
||||
MistralAttention,
|
||||
), "Llama and Mistral attention only are supported."
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def apply_zigzag_ring_attn_monkey_patch_llama():
|
||||
# LlamaAttention._flash_attention_forward = new_flash_attn_forward
|
||||
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
|
||||
LlamaDecoderLayer.forward = new_decoder_forward
|
||||
|
||||
|
||||
def apply_zigzag_ring_attn_monkey_patch_mistral():
|
||||
# MistralAttention._flash_attention_forward = new_flash_attn_forward
|
||||
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
|
||||
MistralDecoderLayer.forward = new_decoder_forward
|
||||
@@ -1,40 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def extract_local(value, rank, world_size, device, dim=1):
|
||||
value_chunks = value.chunk(2 * world_size, dim=dim)
|
||||
local_value = torch.cat(
|
||||
[value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
|
||||
)
|
||||
return local_value.to(device)
|
||||
|
||||
|
||||
def prepare_zigzag_ring_attn_inputs(
|
||||
input_ids, position_ids, target_ids, rank, world_size, device
|
||||
):
|
||||
local_input_ids = extract_local(
|
||||
input_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
local_position_ids = extract_local(
|
||||
position_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
if target_ids is not None:
|
||||
local_target_ids = extract_local(
|
||||
target_ids,
|
||||
rank,
|
||||
world_size,
|
||||
device,
|
||||
)
|
||||
else:
|
||||
local_target_ids = None
|
||||
return {
|
||||
"local_input_ids": local_input_ids,
|
||||
"local_position_ids": local_position_ids,
|
||||
"local_target_ids": local_target_ids,
|
||||
}
|
||||
@@ -549,16 +549,6 @@ class ModelLoader:
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
if self.cfg.sequence_parallel_size > 1:
|
||||
# from axolotl.integrations.easy_context import (
|
||||
# apply_seq_parallel_monkey_patch,
|
||||
# )
|
||||
|
||||
# method = self.cfg.sequence_parallel_mode
|
||||
# model_type = self.cfg.model_config_type
|
||||
|
||||
# # Apply the monkey patch
|
||||
# apply_seq_parallel_monkey_patch(method, model_type)
|
||||
|
||||
register_ring_attn(self.cfg.sequence_parallel_size)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
|
||||
@@ -17,7 +17,6 @@ from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
@@ -357,17 +356,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if cfg.sequence_parallel_size > 1:
|
||||
train_dataset.map(
|
||||
prepare_seq_parallel_inputs,
|
||||
"dist_flash_attn",
|
||||
lambda batch: batch["input_ids"],
|
||||
lambda batch: batch["position_ids"],
|
||||
lambda batch: batch["target_ids"],
|
||||
accelerator.process_index,
|
||||
accelerator.num_processes,
|
||||
accelerator.device,
|
||||
)
|
||||
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
|
||||
Reference in New Issue
Block a user