
Commit 62e87e8

Enable UPSTREAM_PYTORCH_PROFILER method for backward Flash Attention (#3169)

Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev authored Jan 15, 2025
1 parent 35dcc03 commit 62e87e8
Showing 4 changed files with 54 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/triton-benchmarks.yml
@@ -254,7 +254,7 @@ jobs:
run: |
cd benchmarks/triton_kernels_benchmark
FA_KERNEL_MODE="bwd" \
BENCHMARKING_METHOD="ELAPSED_TIME" python flash_attention_benchmark.py --reports $REPORTS
python flash_attention_benchmark.py --reports $REPORTS
mv $REPORTS/attn-performance.csv $REPORTS/attn-bwd-performance.csv
source ../../scripts/capture-hw-details.sh
6 changes: 6 additions & 0 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -148,6 +148,12 @@ def extract_kernels(funcs):
return kernels

kernels = [extract_kernels(func.cpu_children) for func in functions]
# For example, for backward FA, the kernel list can be empty for one of the threads.
# Keep in mind that the `backward` function is launched in another thread and
# additionally requires a `record_function` call in that thread for its
# kernels to be registered correctly.
# For details: https://github.com/pytorch/pytorch/issues/144778
kernels = [kernel for kernel in kernels if kernel != []]
assert len(kernels) == n_repeat, "the profiling number does not match"
# Convert the times to milliseconds.
times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float)
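
To make the new filtering step concrete, here is a small, self-contained sketch (not the benchmark's actual harness): a stand-in event type with only a `duration` field, assumed to be in microseconds as implied by the `* 1e-3` conversion above, plays the role of the profiler's kernel events.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the profiler's kernel events; the real objects come
# from the PyTorch profiler and carry more fields than just `duration`.
@dataclass
class FakeKernelEvent:
    duration: float  # microseconds, mirroring `k.duration` above


n_repeat = 2
# One record per profiled thread per repeat: the record from a thread that did
# not launch the kernels (the backward-FA case described above) comes back empty.
kernels = [
    [FakeKernelEvent(120.0), FakeKernelEvent(80.0)],
    [],
    [FakeKernelEvent(95.0)],
]

# Drop records with no registered kernels, then convert the summed durations to
# milliseconds, matching the filtering and scaling in benchmark_testing.py.
kernels = [ks for ks in kernels if ks != []]
assert len(kernels) == n_repeat, "the profiling number does not match"
times_ms = [sum(k.duration for k in ks) * 1e-3 for ks in kernels]
print(times_ms)  # [0.2, 0.095]
```
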
83 changes: 46 additions & 37 deletions benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -1,5 +1,8 @@
import os
import contextlib

import torch
from torch.profiler import record_function
import triton
import triton.language as tl

@@ -476,43 +479,49 @@ def forward(ctx, q, k, v, causal, sm_scale):

@staticmethod
def backward(ctx, do):
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
PRE_BLOCK = 128
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do, #
delta, #
BATCH, N_HEAD, N_CTX, #
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
M, delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
N_HEAD, N_CTX, #
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
HEAD_DIM=ctx.HEAD_DIM, #
num_warps=NUM_WARPS, #
num_stages=NUM_STAGES #
)
# FIXME: It is unclear to what extent this behavior is expected.
# Consider removing the `record_function` call from here once
# https://github.com/pytorch/pytorch/issues/144778 has more details.
with record_function(
'__profile_kernel_of_func_bwd_fa'
) if benchmark_suit.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext():
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
PRE_BLOCK = 128
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do, #
delta, #
BATCH, N_HEAD, N_CTX, #
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
M, delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
N_HEAD, N_CTX, #
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
HEAD_DIM=ctx.HEAD_DIM, #
num_warps=NUM_WARPS, #
num_stages=NUM_STAGES #
)

return dq, dk, dv, None, None

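
The pattern introduced in this hunk — wrapping the backward body in `record_function` only when the profiler-based benchmarking method is active, so that kernels launched from the autograd worker thread are attributed under a known marker — can be shown with a minimal, hedged sketch. `Square` is a toy autograd function, and the module-level `BENCHMARKING_METHOD` stands in for `benchmark_suit.BENCHMARKING_METHOD`; the marker name matches the one used above.

```python
import contextlib
import os

import torch
from torch.profiler import record_function

# Stand-in for benchmark_suit.BENCHMARKING_METHOD (assumed to be configurable
# through an environment variable and to default to the profiler-based method).
BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")


class Square(torch.autograd.Function):
    """Toy autograd function standing in for the Flash Attention op."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # Autograd runs backward in a separate thread, so the marker has to be
        # emitted from inside this thread for the profiler to register the
        # kernels (see https://github.com/pytorch/pytorch/issues/144778).
        ctx_manager = record_function(
            '__profile_kernel_of_func_bwd_fa'
        ) if BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext()
        with ctx_manager:
            x, = ctx.saved_tensors
            return 2 * x * grad_out


x = torch.randn(8, requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)
```
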
3 changes: 1 addition & 2 deletions scripts/test-triton.sh
@@ -299,8 +299,7 @@ run_benchmark_attention() {

echo "Backward - Default path:"
FA_KERNEL_MODE="bwd" \
BENCHMARKING_METHOD="ELAPSED_TIME" \
python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
}

run_benchmarks() {
