Fix import
This commit is contained in:
@@ -4,7 +4,21 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from .utils import calculate_settings
|
||||
|
||||
MAX_FUSED_SIZE = 65536
|
||||
|
||||
def calculate_settings(n):
|
||||
BLOCK_SIZE = triton.next_power_of_2(n)
|
||||
# CUDA only supports 65536 - 2^16 threads per block
|
||||
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
||||
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
||||
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 32768: num_warps = 32
|
||||
elif BLOCK_SIZE >= 8192: num_warps = 16
|
||||
elif BLOCK_SIZE >= 2048: num_warps = 8
|
||||
return BLOCK_SIZE, num_warps
|
||||
pass
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
||||
|
||||
Reference in New Issue
Block a user