optimize moe + lora
This commit is contained in:
@@ -1330,6 +1330,224 @@ def _group_bwd_lora(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _group_bwd_split_configs():
|
||||||
|
"""Autotune configs for split dA/dB kernels."""
|
||||||
|
configs = []
|
||||||
|
for block_m, block_dim, warps, stages in product(
|
||||||
|
[32, 64, 128], # BLOCK_M (token tile)
|
||||||
|
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
|
||||||
|
[4, 8], # num_warps
|
||||||
|
[3, 4, 5], # num_stages
|
||||||
|
):
|
||||||
|
configs.append(
|
||||||
|
triton.Config(
|
||||||
|
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
|
||||||
|
num_stages=stages,
|
||||||
|
num_warps=warps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_split_configs(configs, named_args, **kwargs):
|
||||||
|
"""Prune split kernel configs based on SMEM capacity."""
|
||||||
|
smem_cap = _get_smem_capacity()
|
||||||
|
block_r = named_args.get("BLOCK_R", 64)
|
||||||
|
inner_dim = named_args.get("INNER_DIM", 2048)
|
||||||
|
|
||||||
|
# Fixed inner tile for reduction dimension
|
||||||
|
BLOCK_INNER = 64
|
||||||
|
|
||||||
|
pruned = []
|
||||||
|
for config in configs:
|
||||||
|
block_m = config.kwargs["BLOCK_M"]
|
||||||
|
block_dim = config.kwargs["BLOCK_DIM"]
|
||||||
|
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
|
||||||
|
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
|
||||||
|
# LoRA weights held in registers: [INNER, R] or [R, DIM]
|
||||||
|
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
|
||||||
|
if smem <= smem_cap - _SMEM_SLACK:
|
||||||
|
pruned.append(config)
|
||||||
|
|
||||||
|
if pruned:
|
||||||
|
return pruned
|
||||||
|
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
|
||||||
|
return [configs[0]]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=_group_bwd_split_configs(),
|
||||||
|
key=["M", "K", "N"],
|
||||||
|
prune_configs_by={"early_config_prune": _prune_split_configs},
|
||||||
|
)
|
||||||
|
@triton.heuristics({
|
||||||
|
"NO_DIM_MASK": lambda args: (
|
||||||
|
(args["K"] % args["BLOCK_DIM"]) == 0
|
||||||
|
if args["COMPUTE_DA"]
|
||||||
|
else (args["N"] % args["BLOCK_DIM"]) == 0
|
||||||
|
),
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _group_bwd_lora_split(
|
||||||
|
# Data tensors (DY and X are always present)
|
||||||
|
DY_ptr, stride_dym, stride_dyn,
|
||||||
|
X_ptr, stride_xm, stride_xk,
|
||||||
|
# LoRA weight for the inner reduction (B for dA, A for dB)
|
||||||
|
LW_ptr, stride_lw0, stride_lw1,
|
||||||
|
# Output gradient tensor (dA or dB)
|
||||||
|
OUT_ptr, stride_out0, stride_out1,
|
||||||
|
# Expert offsets
|
||||||
|
expert_offsets_ptr,
|
||||||
|
# Dimensions
|
||||||
|
M, K: tl.constexpr, N: tl.constexpr,
|
||||||
|
ACTUAL_R: tl.constexpr, BLOCK_R: tl.constexpr,
|
||||||
|
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
|
||||||
|
scaling,
|
||||||
|
# Mode flag
|
||||||
|
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
|
||||||
|
# Tile sizes
|
||||||
|
BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr,
|
||||||
|
allow_tf32: tl.constexpr,
|
||||||
|
NO_DIM_MASK: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Unified split kernel for LoRA gradient computation.
|
||||||
|
|
||||||
|
When COMPUTE_DA=True:
|
||||||
|
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
|
||||||
|
Grid: (E, cdiv(K, BLOCK_DIM))
|
||||||
|
- outer_ptr/stride = X (read [M, K_block])
|
||||||
|
- inner reduction over N using DY and B
|
||||||
|
- output shape [BLOCK_R, BLOCK_DIM]
|
||||||
|
|
||||||
|
When COMPUTE_DA=False:
|
||||||
|
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
|
||||||
|
Grid: (E, cdiv(N, BLOCK_DIM))
|
||||||
|
- outer_ptr/stride = DY (read [M, N_block])
|
||||||
|
- inner reduction over K using X and A
|
||||||
|
- output shape [BLOCK_DIM, BLOCK_R]
|
||||||
|
|
||||||
|
No atomic adds — each (E, dim_block) pair is written by exactly one block.
|
||||||
|
"""
|
||||||
|
E_idx = tl.program_id(0)
|
||||||
|
dim_block_id = tl.program_id(1)
|
||||||
|
|
||||||
|
if E_idx == 0:
|
||||||
|
start_idx = 0
|
||||||
|
else:
|
||||||
|
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||||
|
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||||
|
num_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
# Output dimension tile (K for dA, N for dB)
|
||||||
|
if COMPUTE_DA:
|
||||||
|
OUT_DIM: tl.constexpr = K
|
||||||
|
else:
|
||||||
|
OUT_DIM: tl.constexpr = N
|
||||||
|
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||||
|
dim_mask = dim_block < OUT_DIM
|
||||||
|
R_block = tl.arange(0, BLOCK_R)
|
||||||
|
R_mask = R_block < ACTUAL_R
|
||||||
|
lora_offset = E_idx * ACTUAL_R
|
||||||
|
|
||||||
|
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
|
||||||
|
if COMPUTE_DA:
|
||||||
|
out_blk_ptrs = (
|
||||||
|
OUT_ptr
|
||||||
|
+ (lora_offset + R_block)[:, None] * stride_out0
|
||||||
|
+ dim_block[None, :] * stride_out1
|
||||||
|
)
|
||||||
|
out_mask = R_mask[:, None] & dim_mask[None, :]
|
||||||
|
else:
|
||||||
|
out_blk_ptrs = (
|
||||||
|
OUT_ptr
|
||||||
|
+ dim_block[:, None] * stride_out0
|
||||||
|
+ (lora_offset + R_block)[None, :] * stride_out1
|
||||||
|
)
|
||||||
|
out_mask = dim_mask[:, None] & R_mask[None, :]
|
||||||
|
|
||||||
|
if num_tokens > 0:
|
||||||
|
M_block = tl.arange(0, BLOCK_M)
|
||||||
|
INPUT_DTYPE = X_ptr.dtype.element_ty
|
||||||
|
BLOCK_INNER: tl.constexpr = 64
|
||||||
|
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
|
||||||
|
|
||||||
|
if COMPUTE_DA:
|
||||||
|
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
|
||||||
|
else:
|
||||||
|
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
|
||||||
|
|
||||||
|
M_iters = tl.cdiv(num_tokens, BLOCK_M)
|
||||||
|
for i in range(M_iters):
|
||||||
|
M_idx = start_idx + i * BLOCK_M + M_block
|
||||||
|
M_mask = M_idx < end_idx
|
||||||
|
|
||||||
|
if COMPUTE_DA:
|
||||||
|
# Load X[M, K_block] (the "outer" tensor for dA)
|
||||||
|
outer = tl.load(
|
||||||
|
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
|
||||||
|
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
|
||||||
|
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
|
||||||
|
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||||
|
inner_range = tl.arange(0, BLOCK_INNER)
|
||||||
|
for j in range(inner_iters):
|
||||||
|
inn_off = j * BLOCK_INNER + inner_range
|
||||||
|
inn_mask = inn_off < N
|
||||||
|
|
||||||
|
dy_tile = tl.load(
|
||||||
|
DY_ptr + M_idx[:, None] * stride_dym + inn_off[None, :] * stride_dyn,
|
||||||
|
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
|
||||||
|
lw_tile = tl.load(
|
||||||
|
LW_ptr + inn_off[:, None] * stride_lw0 + (lora_offset + R_block)[None, :] * stride_lw1,
|
||||||
|
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
|
||||||
|
|
||||||
|
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
|
||||||
|
acc += tl.dot(tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32)
|
||||||
|
else:
|
||||||
|
# Load DY[M, N_block] (the "outer" tensor for dB)
|
||||||
|
outer = tl.load(
|
||||||
|
DY_ptr + M_idx[:, None] * stride_dym + dim_block[None, :] * stride_dyn,
|
||||||
|
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
|
||||||
|
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
|
||||||
|
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||||
|
inner_range = tl.arange(0, BLOCK_INNER)
|
||||||
|
for j in range(inner_iters):
|
||||||
|
inn_off = j * BLOCK_INNER + inner_range
|
||||||
|
inn_mask = inn_off < K
|
||||||
|
|
||||||
|
x_tile = tl.load(
|
||||||
|
X_ptr + M_idx[:, None] * stride_xm + inn_off[None, :] * stride_xk,
|
||||||
|
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
|
||||||
|
# We want A[e]^T: [K, R], so load as [K_inner, R]
|
||||||
|
lw_tile = tl.load(
|
||||||
|
LW_ptr + (lora_offset + R_block)[None, :] * stride_lw0 + inn_off[:, None] * stride_lw1,
|
||||||
|
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
|
||||||
|
|
||||||
|
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
|
||||||
|
acc += tl.dot(tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32)
|
||||||
|
|
||||||
|
tl.store(out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask)
|
||||||
|
else:
|
||||||
|
# Zero out this expert's slice — needed because output uses empty_like
|
||||||
|
if COMPUTE_DA:
|
||||||
|
tl.store(out_blk_ptrs, tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
|
||||||
|
else:
|
||||||
|
tl.store(out_blk_ptrs, tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
|
||||||
|
|
||||||
|
|
||||||
def group_bwd_lora(
|
def group_bwd_lora(
|
||||||
DY: torch.Tensor,
|
DY: torch.Tensor,
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
@@ -1344,6 +1562,9 @@ def group_bwd_lora(
|
|||||||
"""
|
"""
|
||||||
Compute LoRA gradients for A and B on expert-grouped data.
|
Compute LoRA gradients for A and B on expert-grouped data.
|
||||||
|
|
||||||
|
Uses split dA/dB kernels that eliminate atomic adds by giving each
|
||||||
|
(expert, output_block) pair its own thread block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
||||||
X: Input [M_total, K] (grouped by expert)
|
X: Input [M_total, K] (grouped by expert)
|
||||||
@@ -1361,46 +1582,45 @@ def group_bwd_lora(
|
|||||||
K = X.size(1)
|
K = X.size(1)
|
||||||
N = DY.size(1)
|
N = DY.size(1)
|
||||||
|
|
||||||
# Zero-init for atomic accumulation
|
# No zero-init needed: the split kernels write zeros for experts with
|
||||||
dA = torch.zeros_like(lora_A)
|
# zero routed tokens directly in the kernel (else branch).
|
||||||
dB = torch.zeros_like(lora_B)
|
dA = torch.empty_like(lora_A)
|
||||||
|
dB = torch.empty_like(lora_B)
|
||||||
|
|
||||||
BLOCK_R = _block_r_for_rank(R)
|
BLOCK_R = _block_r_for_rank(R)
|
||||||
|
|
||||||
def grid(META):
|
def grid_dA(META):
|
||||||
return (
|
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
|
||||||
E * triton.cdiv(K, META["BLOCK_K"]),
|
|
||||||
triton.cdiv(N, META["BLOCK_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
_group_bwd_lora[grid](
|
_group_bwd_lora_split[grid_dA](
|
||||||
DY,
|
DY, DY.stride(0), DY.stride(1),
|
||||||
DY.stride(0),
|
X, X.stride(0), X.stride(1),
|
||||||
DY.stride(1),
|
lora_B, lora_B.stride(0), lora_B.stride(1),
|
||||||
X,
|
dA, dA.stride(0), dA.stride(1),
|
||||||
X.stride(0),
|
|
||||||
X.stride(1),
|
|
||||||
lora_A,
|
|
||||||
lora_A.stride(0),
|
|
||||||
lora_A.stride(1),
|
|
||||||
lora_B,
|
|
||||||
lora_B.stride(0),
|
|
||||||
lora_B.stride(1),
|
|
||||||
dA,
|
|
||||||
dA.stride(0),
|
|
||||||
dA.stride(1),
|
|
||||||
dB,
|
|
||||||
dB.stride(0),
|
|
||||||
dB.stride(1),
|
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
M=DY.size(0),
|
M=DY.size(0), K=K, N=N,
|
||||||
K=K,
|
ACTUAL_R=R, BLOCK_R=BLOCK_R,
|
||||||
N=N,
|
INNER_DIM=N,
|
||||||
ACTUAL_R=R, # True LoRA rank
|
|
||||||
BLOCK_R=BLOCK_R, # Padded tile size
|
|
||||||
scaling=scaling,
|
scaling=scaling,
|
||||||
ACC_TYPE=tl.float32,
|
COMPUTE_DA=True,
|
||||||
allow_tf32=ALLOW_TF32,
|
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
|
||||||
|
)
|
||||||
|
|
||||||
|
def grid_dB(META):
|
||||||
|
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
|
||||||
|
|
||||||
|
_group_bwd_lora_split[grid_dB](
|
||||||
|
DY, DY.stride(0), DY.stride(1),
|
||||||
|
X, X.stride(0), X.stride(1),
|
||||||
|
lora_A, lora_A.stride(0), lora_A.stride(1),
|
||||||
|
dB, dB.stride(0), dB.stride(1),
|
||||||
|
expert_offsets,
|
||||||
|
M=DY.size(0), K=K, N=N,
|
||||||
|
ACTUAL_R=R, BLOCK_R=BLOCK_R,
|
||||||
|
INNER_DIM=K,
|
||||||
|
scaling=scaling,
|
||||||
|
COMPUTE_DA=False,
|
||||||
|
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
|
||||||
)
|
)
|
||||||
|
|
||||||
return dA, dB
|
return dA, dB
|
||||||
|
|||||||
@@ -61,7 +61,13 @@ class KernelsPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.kernels.KernelsArgs"
|
return "axolotl.integrations.kernels.KernelsArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
|
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
|
||||||
|
|
||||||
|
# Prefer text backbone type for VLMs, but fall back to base type
|
||||||
|
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||||
|
if moe_model_type not in SPARSE_MOE_BLOCK and cfg.model_config_type in SPARSE_MOE_BLOCK:
|
||||||
|
moe_model_type = cfg.model_config_type
|
||||||
|
|
||||||
if cfg.use_scattermoe:
|
if cfg.use_scattermoe:
|
||||||
self._register_kernels()
|
self._register_kernels()
|
||||||
|
|||||||
@@ -505,6 +505,20 @@ class ModelLoader:
|
|||||||
elif not is_ds_zero3:
|
elif not is_ds_zero3:
|
||||||
self.model_kwargs["device_map"] = device_map
|
self.model_kwargs["device_map"] = device_map
|
||||||
|
|
||||||
|
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
|
||||||
|
# so the actual VRAM usage is much less than bf16 estimates.
|
||||||
|
# When device_map is "auto", accelerate's infer_auto_device_map computes
|
||||||
|
# the device map at bf16 size (before quantization), causing it to offload
|
||||||
|
# layers to CPU, which BnB then rejects. Force single-GPU placement to
|
||||||
|
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
|
||||||
|
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
|
||||||
|
"auto",
|
||||||
|
None,
|
||||||
|
):
|
||||||
|
self.model_kwargs["device_map"] = {
|
||||||
|
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
}
|
||||||
|
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
if "mps" in str(cur_device):
|
if "mps" in str(cur_device):
|
||||||
self.model_kwargs["device_map"] = "mps:0"
|
self.model_kwargs["device_map"] = "mps:0"
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from transformers import (
|
|||||||
class PytorchProfilerCallback(TrainerCallback):
|
class PytorchProfilerCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||||
|
|
||||||
|
Also runs torch.profiler to produce a Chrome trace for timing analysis.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||||
@@ -26,9 +28,12 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
if profiler_steps_start == 0:
|
if profiler_steps_start == 0:
|
||||||
# start recording memory allocations before everything is allocated, because if we start
|
# start recording memory allocations before everything is allocated, because if we start
|
||||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||||
torch.cuda.memory._record_memory_history(enabled="all")
|
torch.cuda.memory._record_memory_history(
|
||||||
|
enabled="all", stacks="all"
|
||||||
|
)
|
||||||
profiler_steps_start = -1
|
profiler_steps_start = -1
|
||||||
self.profiler_steps_start = profiler_steps_start
|
self.profiler_steps_start = profiler_steps_start
|
||||||
|
self._profiler = None
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self,
|
self,
|
||||||
@@ -38,7 +43,22 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if state.global_step == self.profiler_steps_start:
|
if state.global_step == self.profiler_steps_start:
|
||||||
torch.cuda.memory._record_memory_history(enabled="all")
|
torch.cuda.memory._record_memory_history(
|
||||||
|
enabled="all", stacks="all"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start torch.profiler on the first profiled step
|
||||||
|
if state.global_step == max(self.profiler_steps_start, 0):
|
||||||
|
self._profiler = torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
record_shapes=True,
|
||||||
|
profile_memory=True,
|
||||||
|
with_stack=True,
|
||||||
|
)
|
||||||
|
self._profiler.__enter__()
|
||||||
|
|
||||||
def on_step_end(
|
def on_step_end(
|
||||||
self,
|
self,
|
||||||
@@ -55,6 +75,13 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
# tell CUDA to stop recording memory allocations now
|
# tell CUDA to stop recording memory allocations now
|
||||||
torch.cuda.memory._record_memory_history(enabled=None)
|
torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
|
||||||
|
# Stop and export torch.profiler trace
|
||||||
|
if self._profiler is not None:
|
||||||
|
self._profiler.__exit__(None, None, None)
|
||||||
|
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||||
|
self._profiler.export_chrome_trace(str(trace_path))
|
||||||
|
self._profiler = None
|
||||||
|
|
||||||
def on_train_end(
|
def on_train_end(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments,
|
args: TrainingArguments,
|
||||||
@@ -73,3 +100,9 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
|
|
||||||
# tell CUDA to stop recording memory allocations now
|
# tell CUDA to stop recording memory allocations now
|
||||||
torch.cuda.memory._record_memory_history(enabled=None)
|
torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
|
||||||
|
if self._profiler is not None:
|
||||||
|
self._profiler.__exit__(None, None, None)
|
||||||
|
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||||
|
self._profiler.export_chrome_trace(str(trace_path))
|
||||||
|
self._profiler = None
|
||||||
|
|||||||
Reference in New Issue
Block a user