diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index 0a80677bb5..0cfc9b1a2a 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -19,6 +19,10 @@ on:
           - ELAPSED_TIME
           - UPSTREAM_PYTORCH_PROFILER
         default: UPSTREAM_PYTORCH_PROFILER
+      verify:
+        description: Verify the benchmark results
+        type: boolean
+        default: true
       run_name:
         description: Run name
         type: string
@@ -46,6 +50,7 @@ permissions: read-all
 env:
   PYTHON_VERSION: "3.10"
   BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER' }}
+  VERIFY: ${{ (github.event_name == 'pull_request' || github.event_name == 'schedule' || inputs.verify) && '1' || '0' }}
   TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
 
 jobs:
diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py
index b62432ad6a..a340915d9c 100644
--- a/benchmarks/triton_kernels_benchmark/__init__.py
+++ b/benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,8 +1,6 @@
 import os
 
-from triton.testing import assert_close
-
-from .benchmark_testing import do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
+from .benchmark_testing import assert_close, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
 
 if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     os.environ["INJECT_PYTORCH"] = "True"
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 4e94416a4d..cb94b538bc 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -2,9 +2,10 @@
 import itertools
 import os
 
-from triton.testing import Benchmark
+from triton.testing import assert_close as triton_assert_close, Benchmark
 
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
+VERIFY = os.getenv("VERIFY", "1") == "1"
 
 
 def synchronize():
@@ -161,6 +162,11 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
+def assert_close(x_fn, y_fn, atol=None, rtol=None, err_msg=""):
+    if VERIFY:
+        triton_assert_close(x_fn(), y_fn(), atol, rtol, err_msg)
+
+
 def perf_report(benchmarks):
     """
     Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
index d0446a4984..9bb03656a1 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -576,9 +576,9 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
         if MODE == 'fwd':
             atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
         else:
-            benchmark_suit.assert_close(triton_o, torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'xetla':
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index 6782e92d6b..799bdb1b53 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -130,7 +130,7 @@ def benchmark(M, N, provider):
         out = torch.empty_like(x, device="xpu")
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
+        benchmark_suit.assert_close(triton_fn, torch_fn, err_msg="triton to torch")
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)
 
     elif provider == "torch-jit":
@@ -143,7 +143,7 @@ def benchmark(M, N, provider):
         out = torch.empty_like(x, device="xpu")
         xetla_fn = lambda: func(x, out, 0)
         torch_fn = lambda: torch.softmax(x, axis=-1)
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, err_msg="xetla to torch")
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)
 
     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index 860595a06b..97dc321067 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -291,7 +291,7 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'xetla':
@@ -320,7 +320,7 @@ def xetla_func_with_acc_allocation():
 
         xetla_fn = xetla_func_with_acc_allocation
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 33551bf8b4..f4964e6d47 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -306,7 +306,7 @@ def benchmark(B, M, N, K, dtype, provider):
         if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
                                                        [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
             # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 0faeead793..65593f731e 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -271,7 +271,7 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 8b13827e0c..9b34558593 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -259,7 +259,7 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index c4dd86d834..904d426556 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -154,7 +154,7 @@ def benchmark(M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'xetla':
@@ -167,7 +167,7 @@ def benchmark(M, N, K, provider):
 
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 6ef40be902..29bd68698e 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -275,7 +275,7 @@ def benchmark(M, N, K, provider):
         c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=1e-2, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'xetla':
@@ -288,7 +288,7 @@ def benchmark(M, N, K, provider):
 
        xetla_fn = lambda: func(a, b, c, acc, cnt)
        torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
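Taken together, the diff turns `assert_close` from an eager check into a lazily evaluated one: call sites hand over zero-argument callables instead of computed tensors, so a run with `VERIFY=0` never executes the torch reference path. A minimal standalone sketch of the pattern (assuming `triton` is installed; names mirror `benchmark_testing.py` above, and it is an illustration, not part of the PR):

```python
# Sketch of the lazy-verification pattern introduced in benchmark_testing.py.
# Assumption: triton is installed and provides triton.testing.assert_close.
import os

from triton.testing import assert_close as triton_assert_close

# Verification defaults to on; the CI workflow exports VERIFY=0 to skip it.
VERIFY = os.getenv("VERIFY", "1") == "1"


def assert_close(x_fn, y_fn, atol=None, rtol=None, err_msg=""):
    # x_fn and y_fn are zero-argument callables, so the (possibly expensive)
    # reference computation only runs when verification is enabled.
    if VERIFY:
        triton_assert_close(x_fn(), y_fn(), atol, rtol, err_msg)


# Call sites pass the lambdas themselves rather than their results:
#   assert_close(triton_fn, torch_fn, atol=1e-4, rtol=1e-3, err_msg='triton to torch')
# and wrap already-computed tensors where needed:
#   assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
```

This is why the flash-attention backward case wraps `triton_o` and `torch_o` in `lambda:` expressions, and why the commented-out XeTLA checks were updated to the same callable-passing form even though they remain disabled.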