Skip to content

Commit

Permalink
Merge commit 'a637eb292334caca84eef34db53521da5de48bbd'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Feb 2, 2025
2 parents 0e888dc + a637eb2 commit b9ba137
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 32 deletions.
1 change: 1 addition & 0 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def filter_traceback(e: BaseException):
f"{sep}triton{sep}compiler{sep}code_generator.py",
f"{sep}ast.py",
]
BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]

tb = e.__traceback__
frames = []
Expand Down
56 changes: 24 additions & 32 deletions python/tutorials/09-persistent-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ def matmul(a, b):
return c


@triton.jit
def _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return tile_id, pid_m, pid_n


@triton.autotune(
configs=matmul_get_configs(),
key=["M", "N", "K"],
Expand Down Expand Up @@ -264,13 +275,7 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m

tile_id, pid_m, pid_n = _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
Expand Down Expand Up @@ -377,13 +382,7 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m

tile_id, pid_m, pid_n = _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
offs_am = pid_m * BLOCK_SIZE_M
offs_bn = pid_n * BLOCK_SIZE_N

Expand All @@ -394,12 +393,8 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
accumulator = tl.dot(a, b.T, accumulator)

if ki == k_tiles - 1:
tile_id_c += NUM_SMS
group_id = tile_id_c // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id_c % group_size_m)
pid_n = (tile_id_c % num_pid_in_group) // group_size_m
tile_id_c, pid_m, pid_n = _compute_tile_and_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M,
NUM_SMS)

offs_am_c = pid_m * BLOCK_SIZE_M
offs_bn_c = pid_n * BLOCK_SIZE_N
Expand Down Expand Up @@ -541,10 +536,9 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
tiles_per_SM += 1

tile_id = start_pid - NUM_SMS
tile_id_c = start_pid - NUM_SMS
ki = -1

pid_m = 0
pid_n = 0
offs_am = 0
offs_bn = 0

Expand All @@ -555,13 +549,7 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:

tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
tile_id, pid_m, pid_n = _compute_tile_and_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)

offs_am = pid_m * BLOCK_SIZE_M
offs_bn = pid_n * BLOCK_SIZE_N
Expand All @@ -573,18 +561,22 @@ def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, #
accumulator = tl.dot(a, b.T, accumulator)

if ki == k_tiles - 1:
tile_id_c, pid_m, pid_n = _compute_tile_and_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M,
NUM_SMS)
offs_cm = pid_m * BLOCK_SIZE_M
offs_cn = pid_n * BLOCK_SIZE_N

if EPILOGUE_SUBTILE:
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
c0 = acc0.to(dtype)
c_desc.store([offs_am, offs_bn], c0)
c_desc.store([offs_cm, offs_cn], c0)
c1 = acc1.to(dtype)
c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1)
c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1)
else:
c = accumulator.to(dtype)
c_desc.store([offs_am, offs_bn], c)
c_desc.store([offs_cm, offs_cn], c)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

Expand Down

0 comments on commit b9ba137

Please sign in to comment.