[FRONTEND] Fix warmup function for autotuner (#2731)
The name `warmup` was previously used for both a variable and a function on the autotuner, which led to a naming conflict. To resolve it, the variable has been renamed to `num_warmups` (and `rep` to `num_reps`).

Also skipped running the kernels in `test_line_info.py` to reduce CI time; they are now only compiled via `warmup` instead of launched. A short sketch of the naming conflict follows below.
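
For illustration, here is a minimal, hypothetical sketch (not the actual `Autotuner` code) of the kind of shadowing the rename avoids: assigning an instance attribute with the same name as a method makes the method unreachable through the instance.

```python
class Tuner:
    def __init__(self, warmup):
        # The instance attribute shadows the method of the same name,
        # so Tuner(25).warmup() raises TypeError: 'int' object is not callable.
        self.warmup = warmup

    def warmup(self):
        return "compile only"


class FixedTuner:
    def __init__(self, warmup):
        self.num_warmups = warmup  # renamed attribute no longer shadows the method

    def warmup(self):
        return "compile only"


print(FixedTuner(25).warmup())  # "compile only"; Tuner(25).warmup() would fail
```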
Jokeren authored Nov 30, 2023
1 parent b9a6687 commit 42e5d38
Showing 2 changed files with 25 additions and 21 deletions.
12 changes: 5 additions & 7 deletions python/test/unit/language/test_line_info.py
@@ -100,19 +100,17 @@ def test_line_info(func: str):
         pytest.skip("nvdisasm is not available")
 
     shape = (128, )
-    x = torch.arange(0, shape[0], dtype=torch.float32, device='cuda')
-    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
     kernel_info = {}
     if func == "single":
-        kernel_info = kernel_single[(1,)](x, y, BLOCK=shape[0])
+        kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,))
     elif func == "call":
-        kernel_info = kernel_call[(1,)](x, y, BLOCK=shape[0])
+        kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,))
     elif func == "call_noinline":
-        kernel_info = kernel_call_noinline[(1,)](x, y, BLOCK=shape[0])
+        kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,))
     elif func == "multi_files":
-        kernel_info = kernel_multi_files[(1,)](x, y, BLOCK=shape[0])
+        kernel_info = kernel_multi_files.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,))
     elif func == "autotune":
-        kernel_info = kernel_autotune[(1,)](x, y, SIZE=shape[0])
+        kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0]
 
     file_lines = extract_file_lines(kernel_info.asm["cubin"])
     if func == "single":
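The pattern above relies on `JITFunction.warmup`, which compiles a kernel for a given signature and grid without launching it; passing dtypes instead of real tensors means no device memory is allocated. A hedged, self-contained sketch (the kernel below is illustrative, not from the test):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    # Load BLOCK contiguous floats, add 1.0, and store them back out.
    offs = tl.arange(0, BLOCK)
    tl.store(y_ptr + offs, tl.load(x_ptr + offs) + 1.0)


# Compile only: dtypes stand in for real tensors, so nothing runs on the GPU.
compiled = add_one_kernel.warmup(torch.float32, torch.float32, BLOCK=128, grid=(1,))
print(compiled.asm.keys())  # includes "cubin", which the test inspects for line info
```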
34 changes: 20 additions & 14 deletions python/triton/runtime/autotuner.py
@@ -89,8 +89,8 @@ def _post_hook(args):
         self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune)
 
         self.fn = fn
-        self.warmup = warmup
-        self.rep = rep
+        self.num_warmups = warmup
+        self.num_reps = rep
 
     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -113,13 +113,14 @@ def kernel_call():
                 num_stages=config.num_stages,
                 num_ctas=config.num_ctas,
                 enable_warp_specialization=config.enable_warp_specialization,
+                # TODO: Make it configurable
+                # enable_persistent=False,
                 **current,
             )
             self.post_hook(args)
 
         try:
-            return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
+            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
         except OutOfResources:
             return [float("inf"), float("inf"), float("inf")]
 
@@ -184,7 +185,8 @@ def prune_configs(self, kwargs):
                         num_warps=config.num_warps,
                         num_ctas=config.num_ctas,
                         enable_warp_specialization=config.enable_warp_specialization,
-                        enable_persistent=config.enable_persistent,
+                        # TODO: Make it configurable
+                        # enable_persistent=False,
                     )
                     for config in pruned_configs
                 }
@@ -193,18 +195,22 @@ def prune_configs(self, kwargs):
 
     def warmup(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
+        ret = []
         for config in self.prune_configs(kwargs):
-            self.fn.warmup(
-                *args,
-                num_warps=config.num_warps,
-                num_ctas=config.num_ctas,
-                num_stages=config.num_stages,
-                enable_warp_specialization=config.enable_warp_specialization,
-                enable_persistent=config.enable_persistent,
-                **kwargs,
-                **config.kwargs,
-            )
+            ret.append(
+                self.fn.warmup(
+                    *args,
+                    num_warps=config.num_warps,
+                    num_ctas=config.num_ctas,
+                    num_stages=config.num_stages,
+                    enable_warp_specialization=config.enable_warp_specialization,
+                    # TODO: Make it configurable
+                    # enable_persistent=False,
+                    **kwargs,
+                    **config.kwargs,
+                ))
         self.nargs = None
+        return ret
 
 
 class Config:
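With the change to `Autotuner.warmup` above, warming up an autotuned kernel returns one compiled kernel per config that survives pruning, which is why the test indexes `[0]`. A hedged usage sketch (the kernel and configs below are illustrative, not from the commit):

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({}, num_warps=4), triton.Config({}, num_warps=8)],
    key=["SIZE"],
)
@triton.jit
def copy_kernel(src_ptr, dst_ptr, SIZE: tl.constexpr):
    # Copy SIZE contiguous elements from src to dst.
    offs = tl.arange(0, SIZE)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))


# Compile every candidate config without launching anything.
variants = copy_kernel.warmup(torch.float32, torch.float32, SIZE=128, grid=(1,))
print(len(variants))           # one compiled kernel per surviving config (2 here)
print(variants[0].asm.keys())  # the first entry is what test_line_info.py inspects
```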
