From 62e87e8fd35a16341ac6fac4b8b41f327c5956d1 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 15 Jan 2025 19:34:30 +0100
Subject: [PATCH] Enable `UPSTREAM_PYTORCH_PROFILER` method for backward Flash
 Attention (#3169)

Signed-off-by: Anatoly Myachev
---
 .github/workflows/triton-benchmarks.yml |  2 +-
 .../benchmark_testing.py                |  6 ++
 .../flash_attention_benchmark.py        | 83 ++++++++++---------
 scripts/test-triton.sh                  |  3 +-
 4 files changed, 54 insertions(+), 40 deletions(-)

diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index 0cfc9b1a2a..7d31283979 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -254,7 +254,7 @@ jobs:
         run: |
           cd benchmarks/triton_kernels_benchmark
           FA_KERNEL_MODE="bwd" \
-          BENCHMARKING_METHOD="ELAPSED_TIME" python flash_attention_benchmark.py --reports $REPORTS
+          python flash_attention_benchmark.py --reports $REPORTS
           mv $REPORTS/attn-performance.csv $REPORTS/attn-bwd-performance.csv

           source ../../scripts/capture-hw-details.sh
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index cb94b538bc..a90e3d3286 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -148,6 +148,12 @@ def extract_kernels(funcs):
         return kernels

     kernels = [extract_kernels(func.cpu_children) for func in functions]
+    # For example, for backward FA the kernel list can be empty for one of the threads.
+    # Keep in mind that the `backward` function is launched in a separate thread and
+    # additionally requires a `record_function` call in that thread so that its
+    # kernels are registered correctly.
+    # For details: https://github.com/pytorch/pytorch/issues/144778
+    kernels = [kernel for kernel in kernels if kernel != []]
     assert len(kernels) == n_repeat, "the profiling number not match"
     # Make the time to the milliseconds.
     times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
index 9bb03656a1..ef40c3b507 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -1,5 +1,8 @@
 import os
+import contextlib
+
 import torch
+from torch.profiler import record_function

 import triton
 import triton.language as tl
@@ -476,43 +479,49 @@ def forward(ctx, q, k, v, causal, sm_scale):

     @staticmethod
     def backward(ctx, do):
-        q, k, v, o, M = ctx.saved_tensors
-        assert do.is_contiguous()
-        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
-        dq = torch.empty_like(q)
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
-        BATCH, N_HEAD, N_CTX = q.shape[:3]
-        PRE_BLOCK = 128
-        NUM_WARPS, NUM_STAGES = 4, 5
-        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
-        BLK_SLICE_FACTOR = 2
-        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
-        arg_k = k
-        arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
-        PRE_BLOCK = 128
-        assert N_CTX % PRE_BLOCK == 0
-        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
-        delta = torch.empty_like(M)
-        _attn_bwd_preprocess[pre_grid](
-            o, do,  #
-            delta,  #
-            BATCH, N_HEAD, N_CTX,  #
-            BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM  #
-        )
-        grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
-        _attn_bwd[grid](
-            q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
-            M, delta,  #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
-            N_HEAD, N_CTX,  #
-            BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
-            BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
-            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
-            HEAD_DIM=ctx.HEAD_DIM,  #
-            num_warps=NUM_WARPS,  #
-            num_stages=NUM_STAGES  #
-        )
+        # FIXME: It is not yet clear whether this behavior is expected.
+        # Consider removing the `record_function` call from here once
+        # https://github.com/pytorch/pytorch/issues/144778 has more details.
+        with record_function(
+                '__profile_kernel_of_func_bwd_fa'
+        ) if benchmark_suit.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext():
+            q, k, v, o, M = ctx.saved_tensors
+            assert do.is_contiguous()
+            assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
+            BATCH, N_HEAD, N_CTX = q.shape[:3]
+            PRE_BLOCK = 128
+            NUM_WARPS, NUM_STAGES = 4, 5
+            BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
+            BLK_SLICE_FACTOR = 2
+            RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
+            arg_k = k
+            arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
+            PRE_BLOCK = 128
+            assert N_CTX % PRE_BLOCK == 0
+            pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
+            delta = torch.empty_like(M)
+            _attn_bwd_preprocess[pre_grid](
+                o, do,  #
+                delta,  #
+                BATCH, N_HEAD, N_CTX,  #
+                BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM  #
+            )
+            grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
+            _attn_bwd[grid](
+                q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
+                M, delta,  #
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
+                N_HEAD, N_CTX,  #
+                BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
+                BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
+                BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
+                HEAD_DIM=ctx.HEAD_DIM,  #
+                num_warps=NUM_WARPS,  #
+                num_stages=NUM_STAGES  #
+            )
         return dq, dk, dv, None, None


diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh
index e08a34f164..2dfe387877 100755
--- a/scripts/test-triton.sh
+++ b/scripts/test-triton.sh
@@ -299,8 +299,7 @@ run_benchmark_attention() {

     echo "Backward - Default path:"
     FA_KERNEL_MODE="bwd" \
-    BENCHMARKING_METHOD="ELAPSED_TIME" \
-    python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
+    python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
 }

 run_benchmarks() {
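
Note (not part of the patch): the sketch below reproduces, on a toy example, the pattern the patch applies to the Flash Attention backward: conditionally wrapping the body of `backward` in `record_function` so the kernels launched from the autograd worker thread appear under a named region when profiling with the upstream PyTorch profiler. The class `_Square`, the `USE_PYTORCH_PROFILER` flag, and the region labels are illustrative stand-ins (the real code checks `benchmark_suit.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER'`); only the public `torch.profiler` API is assumed.

import contextlib

import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Illustrative stand-in for `benchmark_suit.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER'`.
USE_PYTORCH_PROFILER = True


class _Square(torch.autograd.Function):
    """Toy autograd function standing in for the benchmark's `_attention` class."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # Autograd may execute `backward` on a worker thread (as it does for the
        # device benchmark in this patch); wrapping the body in `record_function`
        # gives that thread a named region the profiler can attribute kernels to.
        cm = record_function('__profile_kernel_of_func_bwd_sketch') if USE_PYTORCH_PROFILER \
            else contextlib.nullcontext()
        with cm:
            x, = ctx.saved_tensors
            return 2 * x * grad_out


x = torch.randn(1024, requires_grad=True)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function('__profile_kernel_of_func'):
        y = _Square.apply(x).sum()
    y.backward()

print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))

In the resulting trace, the forward work sits under `__profile_kernel_of_func` and the backward work under `__profile_kernel_of_func_bwd_sketch`, which is why the benchmark can keep `UPSTREAM_PYTORCH_PROFILER` for the backward FA mode instead of falling back to `ELAPSED_TIME`.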