
Commit 35dcc03
[BENCHMARKS][CI] Add ability to disable verification (#3166)
whitneywhtsang authored Jan 15, 2025
1 parent 69674be commit 35dcc03
Showing 11 changed files with 26 additions and 17 deletions.

.github/workflows/triton-benchmarks.yml (+5, -0)
@@ -19,6 +19,10 @@ on:
           - ELAPSED_TIME
           - UPSTREAM_PYTORCH_PROFILER
         default: UPSTREAM_PYTORCH_PROFILER
+      verify:
+        description: Verify the benchmark results
+        type: boolean
+        default: true
       run_name:
         description: Run name
         type: string
@@ -46,6 +50,7 @@ permissions: read-all
 env:
   PYTHON_VERSION: "3.10"
   BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER' }}
+  VERIFY: ${{ (github.event_name == 'pull_request' || github.event_name == 'schedule' || inputs.verify) && '1' || '0' }}
   TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
 
 jobs:
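
Note on the VERIFY expression: verification remains unconditionally enabled for pull_request and schedule events; the new verify input only takes effect on manually dispatched runs. For illustration (assuming the GitHub CLI is available), a manual run with verification disabled could be started with: gh workflow run triton-benchmarks.yml -f verify=false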

benchmarks/triton_kernels_benchmark/__init__.py (+1, -3)
@@ -1,8 +1,6 @@
 import os
 
-from triton.testing import assert_close
-
-from .benchmark_testing import do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
+from .benchmark_testing import assert_close, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
 
 if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     os.environ["INJECT_PYTORCH"] = "True"

benchmarks/triton_kernels_benchmark/benchmark_testing.py (+7, -1)
@@ -2,9 +2,10 @@
 import itertools
 import os
 
-from triton.testing import Benchmark
+from triton.testing import assert_close as triton_assert_close, Benchmark
 
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
+VERIFY = os.getenv("VERIFY", "1") == "1"
 
 
 def synchronize():
@@ -161,6 +162,11 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
+def assert_close(x_fn, y_fn, atol=None, rtol=None, err_msg=""):
+    if VERIFY:
+        triton_assert_close(x_fn(), y_fn(), atol, rtol, err_msg)
+
+
 def perf_report(benchmarks):
     """
     Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
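
A minimal usage sketch of the new wrapper (illustrative only, not code from this commit; it assumes the triton_kernels_benchmark package and torch are importable). Because both arguments are callables, neither is evaluated when VERIFY is '0', so disabling verification also skips the cost of computing the reference result:

    import os
    os.environ["VERIFY"] = "0"  # must be set before the import below: VERIFY is read at module load time

    import torch
    from triton_kernels_benchmark import assert_close

    triton_fn = lambda: torch.ones(4)  # stand-in for a Triton kernel launch
    torch_fn = lambda: torch.ones(4)   # stand-in for the Torch reference implementation

    # With VERIFY=0 this is a no-op; with VERIFY=1 both callables run and their outputs are compared.
    assert_close(triton_fn, torch_fn, atol=1e-4, rtol=1e-3, err_msg="triton to torch")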

@@ -576,9 +576,9 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
         if MODE == 'fwd':
             atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
         else:
-            benchmark_suit.assert_close(triton_o, torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'xetla':
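
Design note on the call-site changes above: in the 'fwd' branch the arguments were already lambdas, so only the trailing call parentheses are dropped; in the backward branch, triton_o and torch_o are tensors materialized earlier, so they are wrapped in lambda: ... purely to match the new callable signature. The wrapping is free, and the tensors are only compared when verification is enabled.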

benchmarks/triton_kernels_benchmark/fused_softmax.py (+2, -2)
@@ -130,7 +130,7 @@ def benchmark(M, N, provider):
         out = torch.empty_like(x, device="xpu")
         triton_fn = lambda: softmax(x, out)
         torch_fn = lambda: torch.softmax(x, axis=-1)
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
+        benchmark_suit.assert_close(triton_fn, torch_fn, err_msg="triton to torch")
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)
 
     elif provider == "torch-jit":

@@ -143,7 +143,7 @@ def benchmark(M, N, provider):
         out = torch.empty_like(x, device="xpu")
         xetla_fn = lambda: func(x, out, 0)
         torch_fn = lambda: torch.softmax(x, axis=-1)
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, err_msg="xetla to torch")
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10)
 
     else:
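
The commented-out xetla check is migrated to the callable form as well, so it stays consistent with the new assert_close signature if re-enabled; the same applies to the commented xetla checks in the GEMM benchmarks below.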

benchmarks/triton_kernels_benchmark/gemm_benchmark.py (+2, -2)
@@ -291,7 +291,7 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c, transpose_a=TRANSPOSE_A, transpose_b=TRANSPOSE_B)
         torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'xetla':

@@ -320,7 +320,7 @@ def xetla_func_with_acc_allocation():
         xetla_fn = xetla_func_with_acc_allocation
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:

@@ -306,7 +306,7 @@ def benchmark(B, M, N, K, dtype, provider):
         if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
                                                        [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
             # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:

@@ -271,7 +271,7 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:

@@ -259,7 +259,7 @@ def benchmark(B, M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py (+2, -2)
@@ -154,7 +154,7 @@ def benchmark(M, N, K, provider):
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'xetla':

@@ -167,7 +167,7 @@ def benchmark(M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py (+2, -2)
@@ -275,7 +275,7 @@ def benchmark(M, N, K, provider):
         c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
-        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-4, rtol=1e-2, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'xetla':

@@ -288,7 +288,7 @@ def benchmark(M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
 
-        # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
+        # benchmark_suit.assert_close(xetla_fn, torch_fn, atol=1e-4, rtol=1.0, err_msg='xetla to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     else:
