Fix import
This commit is contained in:
@@ -4,7 +4,21 @@
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
import torch
|
import torch
|
||||||
from .utils import calculate_settings
|
|
||||||
|
# Largest block size this module is willing to launch in a single fused kernel.
MAX_FUSED_SIZE = 65536


def calculate_settings(n):
    """Choose the Triton launch configuration for a row of ``n`` elements.

    The block size is ``n`` rounded up via ``triton.next_power_of_2``; the
    warp count is then selected from that block size (4, 8, 16 or 32).

    Args:
        n: Number of elements the kernel has to cover in one block.

    Returns:
        A ``(BLOCK_SIZE, num_warps)`` tuple.

    Raises:
        RuntimeError: If the rounded-up block size exceeds ``MAX_FUSED_SIZE``.
    """
    block_size = triton.next_power_of_2(n)
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )

    # Thresholds are scanned from largest to smallest, so the first match
    # wins; block sizes below 2048 fall through to the default of 4 warps.
    for threshold, warps in ((32768, 32), (8192, 16), (2048, 8)):
        if block_size >= threshold:
            return block_size, warps
    return block_size, 4
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
||||||
|
|||||||
Reference in New Issue
Block a user