Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GEMM] Undo small GRF autotune config for transpose A #3323

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TRANSPOSE_A = os.getenv('TRANSPOSE_A', '0') == '1'
TRANSPOSE_B = os.getenv('TRANSPOSE_B', '0') == '1'
use_xetla = not (TRANSPOSE_A or TRANSPOSE_B)
SMALL_GRF = os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' and not TRANSPOSE_A


@triton.autotune(
Expand All @@ -27,16 +28,18 @@
num_stages=s, num_warps=32) for s in [1, 2, 3]
] + [
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
num_stages=s, num_warps=w) for s in [2, 3, 4] for (m, w) in
([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
num_stages=s, num_warps=w)
for s in [2, 3, 4]
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
] + [
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
num_stages=s, num_warps=32) for s in [2]
] + [
triton.Config({'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': m},
num_stages=s, num_warps=w) for s in [2, 3] for (m, w) in
([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
num_stages=s, num_warps=w)
for s in [2, 3]
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
],
key=['M', 'N', 'K'],
)
Expand Down Expand Up @@ -92,8 +95,9 @@ def matmul_kernel_with_block_pointers(
num_stages=s, num_warps=32) for s in [2, 3]
] + [
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': m},
num_stages=s, num_warps=w) for s in [2] for (m, w) in
([('large', 32), ('small', 64)] if os.getenv('TRITON_INTEL_ADVANCED_PATH', '0') == '0' else [('large', 32)])
num_stages=s, num_warps=w)
for s in [2]
for (m, w) in ([('large', 32), ('small', 64)] if SMALL_GRF else [('large', 32)])
] + [
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 1024, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'},
Expand Down