From 8671ed5a0c1111833e63f3ba6b84427ce8cbbbf6 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Wed, 6 Dec 2023 20:26:31 +0000
Subject: [PATCH] Fix import

---
 src/axolotl/monkeypatch/cross_entropy.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/monkeypatch/cross_entropy.py b/src/axolotl/monkeypatch/cross_entropy.py
index 9826eebc9..55d0a8560 100644
--- a/src/axolotl/monkeypatch/cross_entropy.py
+++ b/src/axolotl/monkeypatch/cross_entropy.py
@@ -4,7 +4,21 @@
 import triton
 import triton.language as tl
 import torch
-from .utils import calculate_settings
+
+MAX_FUSED_SIZE = 65536
+
+def calculate_settings(n):
+    BLOCK_SIZE = triton.next_power_of_2(n)
+    # CUDA only supports 65536 - 2^16 threads per block
+    if BLOCK_SIZE > MAX_FUSED_SIZE:
+        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
+                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
+    num_warps = 4
+    if BLOCK_SIZE >= 32768: num_warps = 32
+    elif BLOCK_SIZE >= 8192: num_warps = 16
+    elif BLOCK_SIZE >= 2048: num_warps = 8
+    return BLOCK_SIZE, num_warps
+pass
 
 @triton.jit
 def _cross_entropy_forward(logits_ptr, logits_row_stride,