reduce autotune search space (#3525) [skip ci]
* reduce autotune search space * consistent docstrings
This commit is contained in:
@@ -363,7 +363,7 @@ def _scatter2scatter_lora_configs():
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_N: {32, 64}
|
||||
BLOCK_K: {32, 64, 128}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
@@ -371,7 +371,7 @@ def _scatter2scatter_lora_configs():
|
||||
configs = []
|
||||
for block_m, block_n, block_k, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
@@ -943,16 +943,16 @@ def _scatter2scatter_lora_dX_configs():
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128} (token tile)
|
||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||
BLOCK_K: {32, 64, 128} (output tile)
|
||||
BLOCK_N: {32, 64} (reduction tile)
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
"""
|
||||
configs = []
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||
[32, 64, 128], # BLOCK_K (output dimension)
|
||||
[32, 64], # BLOCK_N (reduction dimension)
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
@@ -1278,9 +1278,9 @@ def _group_bwd_lora_configs():
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
|
||||
BLOCK_K: {32, 64, 128, 256}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_M: {32, 64, 128} (token-loop tile)
|
||||
BLOCK_K: {32, 64, 128}
|
||||
BLOCK_N: {32, 64}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
|
||||
@@ -1289,9 +1289,9 @@ def _group_bwd_lora_configs():
|
||||
"""
|
||||
configs = []
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128, 256], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[32, 64], # BLOCK_N
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user