reduce autotune search space (#3525) [skip ci]
* reduce autotune search space * consistent docstrings
This commit is contained in:
@@ -363,7 +363,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128}
|
BLOCK_M: {32, 64, 128}
|
||||||
BLOCK_N: {32, 64, 128, 256}
|
BLOCK_N: {32, 64}
|
||||||
BLOCK_K: {32, 64, 128}
|
BLOCK_K: {32, 64, 128}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
@@ -371,7 +371,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
configs = []
|
configs = []
|
||||||
for block_m, block_n, block_k, warps, stages in product(
|
for block_m, block_n, block_k, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_N
|
[32, 64], # BLOCK_N
|
||||||
[32, 64, 128], # BLOCK_K
|
[32, 64, 128], # BLOCK_K
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
@@ -943,16 +943,16 @@ def _scatter2scatter_lora_dX_configs():
|
|||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128} (token tile)
|
BLOCK_M: {32, 64, 128} (token tile)
|
||||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
BLOCK_K: {32, 64, 128} (output tile)
|
||||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
BLOCK_N: {32, 64} (reduction tile)
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_m, block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
[32, 64, 128], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
[32, 64, 128], # BLOCK_K (output dimension)
|
||||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
[32, 64], # BLOCK_N (reduction dimension)
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
):
|
):
|
||||||
@@ -1278,9 +1278,9 @@ def _group_bwd_lora_configs():
|
|||||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
|
BLOCK_M: {32, 64, 128} (token-loop tile)
|
||||||
BLOCK_K: {32, 64, 128, 256}
|
BLOCK_K: {32, 64, 128}
|
||||||
BLOCK_N: {32, 64, 128, 256}
|
BLOCK_N: {32, 64}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
|
|
||||||
@@ -1289,9 +1289,9 @@ def _group_bwd_lora_configs():
|
|||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_m, block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
[32, 64, 128, 256], # BLOCK_M
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_K
|
[32, 64, 128], # BLOCK_K
|
||||||
[32, 64, 128, 256], # BLOCK_N
|
[32, 64], # BLOCK_N
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
[3, 4, 5], # num_stages
|
[3, 4, 5], # num_stages
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user