From 2c05847a5fabf91fea7994abd73aaec775442b8f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 21 Mar 2026 18:30:15 -0400
Subject: [PATCH] reduce autotune search space (#3525) [skip ci]

* reduce autotune search space

* consistent docstrings
---
 .../libs/scattermoe_lora/kernels/lora_ops.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
index 16f6da73b..e8d4309f9 100644
--- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
+++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
@@ -363,7 +363,7 @@ def _scatter2scatter_lora_configs():
 
     Search space:
         BLOCK_M: {32, 64, 128}
-        BLOCK_N: {32, 64, 128, 256}
+        BLOCK_N: {32, 64}
         BLOCK_K: {32, 64, 128}
         num_warps: {4, 8}
         num_stages: {3, 4, 5}
@@ -371,7 +371,7 @@
     configs = []
     for block_m, block_n, block_k, warps, stages in product(
         [32, 64, 128],  # BLOCK_M
-        [32, 64, 128, 256],  # BLOCK_N
+        [32, 64],  # BLOCK_N
         [32, 64, 128],  # BLOCK_K
         [4, 8],  # num_warps
         [3, 4, 5],  # num_stages
     ):
@@ -943,16 +943,16 @@ def _scatter2scatter_lora_dX_configs():
 
     Search space:
         BLOCK_M: {32, 64, 128} (token tile)
-        BLOCK_K: {32, 64, 128, 256} (output tile)
-        BLOCK_N: {32, 64, 128, 256} (reduction tile)
+        BLOCK_K: {32, 64, 128} (output tile)
+        BLOCK_N: {32, 64} (reduction tile)
         num_warps: {4, 8}
         num_stages: {3, 4, 5}
     """
     configs = []
     for block_m, block_k, block_n, warps, stages in product(
         [32, 64, 128],  # BLOCK_M
-        [32, 64, 128, 256],  # BLOCK_K (output dimension)
-        [32, 64, 128, 256],  # BLOCK_N (reduction dimension)
+        [32, 64, 128],  # BLOCK_K (output dimension)
+        [32, 64],  # BLOCK_N (reduction dimension)
         [4, 8],  # num_warps
         [3, 4, 5],  # num_stages
     ):
@@ -1278,9 +1278,9 @@ def _group_bwd_lora_configs():
     support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
 
     Search space:
-        BLOCK_M: {32, 64, 128, 256} (token-loop tile)
-        BLOCK_K: {32, 64, 128, 256}
-        BLOCK_N: {32, 64, 128, 256}
+        BLOCK_M: {32, 64, 128} (token-loop tile)
+        BLOCK_K: {32, 64, 128}
+        BLOCK_N: {32, 64}
         num_warps: {4, 8}
         num_stages: {3, 4, 5}
 
@@ -1289,9 +1289,9 @@
     """
     configs = []
    for block_m, block_k, block_n, warps, stages in product(
-        [32, 64, 128, 256],  # BLOCK_M
-        [32, 64, 128, 256],  # BLOCK_K
-        [32, 64, 128, 256],  # BLOCK_N
+        [32, 64, 128],  # BLOCK_M
+        [32, 64, 128],  # BLOCK_K
+        [32, 64],  # BLOCK_N
         [4, 8],  # num_warps
         [3, 4, 5],  # num_stages
     ):
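
Note on the search-space math: each helper touched above enumerates the full
cross product of its axes with itertools.product, so trimming one axis shrinks
the config count multiplicatively. Cutting BLOCK_N from {32, 64, 128, 256} to
{32, 64} takes _scatter2scatter_lora_configs from 3 * 4 * 3 * 2 * 3 = 216
configs to 3 * 2 * 3 * 2 * 3 = 108; _scatter2scatter_lora_dX_configs drops
from 3 * 4 * 4 * 2 * 3 = 288 to 3 * 3 * 2 * 2 * 3 = 108; and
_group_bwd_lora_configs, which also loses the 256 candidates on BLOCK_M and
BLOCK_K, goes from 4 * 4 * 4 * 2 * 3 = 384 to 108. Autotuning cost is roughly
linear in the number of configs benchmarked, which is what the patch reduces.

For readers without the surrounding file, a minimal sketch of the pattern
follows. The loop bounds and docstring mirror the post-patch hunks above; the
triton.Config call and import layout are assumptions about the rest of
lora_ops.py, not its actual code.

    from itertools import product

    import triton


    def _scatter2scatter_lora_configs():
        """Build autotune configs for the scatter2scatter LoRA kernel.

        Search space:
            BLOCK_M: {32, 64, 128}
            BLOCK_N: {32, 64}
            BLOCK_K: {32, 64, 128}
            num_warps: {4, 8}
            num_stages: {3, 4, 5}
        """
        configs = []
        # Full cross product: 3 * 2 * 3 * 2 * 3 = 108 configs.
        for block_m, block_n, block_k, warps, stages in product(
            [32, 64, 128],  # BLOCK_M
            [32, 64],  # BLOCK_N
            [32, 64, 128],  # BLOCK_K
            [4, 8],  # num_warps
            [3, 4, 5],  # num_stages
        ):
            configs.append(
                # triton.Config takes a dict of kernel meta-parameters plus
                # launch hints; the autotuner benchmarks each one.
                triton.Config(
                    {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
                    num_warps=warps,
                    num_stages=stages,
                )
            )
        return configs

A kernel would consume such a list through the standard decorator, e.g.
@triton.autotune(configs=_scatter2scatter_lora_configs(), key=["M", "N", "K"]);
the key list here is illustrative, since the patch does not show how the
kernels are decorated.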