Integrate Triton up to [632bfc3](https://github.com/openai/triton/com…
chsigg authored and Google-ML-Automation committed Jan 21, 2025
1 parent 907555f commit 59f5703
Showing 2 changed files with 36 additions and 91 deletions.
89 changes: 36 additions & 53 deletions jax_triton/triton_lib.py
@@ -88,8 +88,7 @@
jnp.dtype("uint32"): "u32",
jnp.dtype("uint16"): "u16",
jnp.dtype("uint8"): "u8",
- # Triton defines a 'B' type, which is an alias for both i1 and bool.
- jnp.dtype("bool"): "B",
+ jnp.dtype("bool"): "i1",
}

Grid = Union[int, tuple[int], tuple[int, int], tuple[int, int, int]]
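
The hunk above changes the boolean entry in the JAX-dtype-to-Triton-type table from the old 'B' alias to 'i1'. For orientation, a minimal sketch of how such a table is typically consumed when building a Triton signature string; the table subset and the helper name below are illustrative, not the library's code:

    import numpy as np

    # Hypothetical subset of the dtype table shown above.
    _TYPE_MAP = {
        np.dtype("float32"): "fp32",
        np.dtype("int32"): "i32",
        np.dtype("bool"): "i1",  # after this change; previously "B"
    }

    def array_signature(dtype) -> str:
      # Array arguments are passed by pointer, hence the leading "*".
      return "*" + _TYPE_MAP[np.dtype(dtype)]

    print(array_signature("bool"))  # -> "*i1"
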
@@ -353,30 +352,36 @@ def get_or_create_triton_kernel(
if num_ctas > 1 and compute_capability < 90:
raise ValueError("num_ctas > 1 unsupported before Hopper.")

+ backend = backend_init_func(device, compute_capability)

signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
- # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
- # We assume that all arrays are aligned to 16 bytes, and Triton may use this
- # assumption, unless array args are include in the `do_not_specialize` list.
- # We replace array arguments with mock Torch tensors, to allow us to use
- # `JITFunction._get_config` to get the specialization_attr.
- mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
- args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
- backend = backend_init_func(device, compute_capability)
- for i, _, v in scalar_args:
-   args_for_specialization_attr[i] = v

- specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
+ specialization = [
+     triton.runtime.jit.specialize_impl(
+         types.SimpleNamespace(
+             data_ptr=lambda: 16, dtype=arg_dtype.removeprefix("*")
+         ),
+         backend.get_arg_specialization,
+     )
+     for arg_dtype in arg_dtypes
+ ]
+ attrs = {
+     fn.arg_names[i]: backend.parse_attr(attr)
+     for i, (_, attr) in enumerate(specialization)
+ }
constants = dict(metaparams)
constants.update({k: None for _, k, v in scalar_args if v is None})
- constants.update({fn.arg_names[i]: 1 for (i,) in specialization_attr.equal_to_1})
+ constants.update({fn.arg_names[i]: 1 for i, _, v in scalar_args if v == 1})
for constant in constants:
signature[constant] = "constexpr"

# Cache key should contain any parameter that can affect the compiler output.
cache_key = (
fn,
tuple(signature.items()),
- tuple(specialization_attr.get_fn_attrs()),
+ tuple(specialization),
tuple(constants.items()),
num_warps,
num_stages,
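
The block above replaces the mock-Torch-tensor path (`backend.get_attrs_descriptor` over fake tensors) with Triton's own per-argument specialization: each array is represented by a stand-in object whose `data_ptr()` returns 16, so every pointer looks 16-byte aligned, and the resulting attributes are parsed per argument name. A rough sketch of the idea, using a toy specializer in place of `triton.runtime.jit.specialize_impl` / `backend.get_arg_specialization` (Triton internals whose exact signatures depend on the pinned Triton revision):

    import types

    def toy_arg_specialization(arg):
      """Toy stand-in: report the divisibility hint Triton would derive."""
      attrs = []
      if hasattr(arg, "data_ptr") and arg.data_ptr() % 16 == 0:
        attrs.append(["tt.divisibility", 16])
      return attrs

    # The stand-in "tensor": a 16-byte-aligned pointer with an element type.
    mock_array = types.SimpleNamespace(data_ptr=lambda: 16, dtype="fp32")
    print(toy_arg_specialization(mock_array))  # [['tt.divisibility', 16]]
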
@@ -408,46 +413,22 @@ def get_or_create_triton_kernel(
context = _triton.ir.context()
_triton.ir.load_dialects(context)
backend.load_dialects(context)
- codegen_fns = backend.get_codegen_implementation()

- module = (
-     code_gen.ast_to_ttir(
-         fn,
-         specialization=tc.ASTSource(
-             fn,
-             constexprs=constants,
-             signature=signature,
-             attrs=specialization_attr,
-         ),
-         options=options,
-         codegen_fns=codegen_fns,
-         context=context,
-         module_map=backend.get_module_map(),
-     )
-     if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
-     # Triton changes ASTSource.ast_to_ttir to include module_map. Handle
-     # backward compatibility here.
-     else code_gen.ast_to_ttir(
-         fn,
-         specialization=tc.ASTSource(
-             fn,
-             constexprs=constants,
-             signature=signature,
-             attrs=specialization_attr,
-         ),
-         options=options,
-         codegen_fns=codegen_fns,
-         context=context,
-     )
- )
+ codegen_fns = backend.get_codegen_implementation(options)

+ module = code_gen.ast_to_ttir(
+     fn,
+     tc.ASTSource(
+         fn, constexprs=constants, signature=signature, attrs=attrs
+     ),
+     options=options,
+     codegen_fns=codegen_fns,
+     context=context,
+     module_map=backend.get_module_map(),
+ )
ttir = str(module)

compilation_result = compile_ttir_inplace(
-     module,
-     backend,
-     options,
-     compute_capability,
-     platform
+     module, backend, options, compute_capability, platform
)

kernel_name = compilation_result.name
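
The deleted branch above existed only to support Triton revisions whose `ast_to_ttir` did not yet accept `module_map`; with the pinned Triton the argument is passed unconditionally. The signature-probing trick it relied on is plain Python and is sketched below with a hypothetical `some_api` standing in for `code_gen.ast_to_ttir`:

    import inspect

    def some_api(x, module_map=None):  # hypothetical stand-in
      return x, module_map

    # Probe the callee's signature before passing a keyword it may not know.
    if "module_map" in inspect.getfullargspec(some_api).args:
      result = some_api(1, module_map={"lib": object()})
    else:
      result = some_api(1)
    print(result)
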
@@ -459,7 +440,7 @@ def get_or_create_triton_kernel(
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ptx", "w"
) as f:
- f.write(compilation_result.ptx)
+ f.write(compilation_result.binary)
with open(
f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/{kernel_name}.ttgir", "w"
) as f:
@@ -490,7 +471,7 @@ def get_or_create_triton_kernel(

_COMPILED_KERNEL_CACHE[cache_key] = kernel

- return kernel, specialization_attr
+ return kernel, attrs


def triton_kernel_call_lowering(
@@ -628,15 +609,17 @@ def prune_configs(configs, named_args, **kwargs):

kernel_params = []
zeroed_params_with_sizes = dict(params["zeroed_params_with_sizes"])
+ equal_to_1 = {i for i, _, v in scalar_args if v == 1}
for i, (arg, dtype) in enumerate(zip(args, arg_dtypes)):
if isinstance(arg, core.ShapedArray):
+ arg_attrs = specialization_attr[fn.arg_names[i]]
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
- 16 if (i in specialization_attr.divisibility_16) else 0,
+ 16 if (["tt.divisibility", 16] in arg_attrs) else 0,
)
)
- elif (i,) not in specialization_attr.equal_to_1:
+ elif i not in equal_to_1:
kernel_params.append(
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
)
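
With `get_or_create_triton_kernel` now returning a mapping from argument name to parsed attributes instead of an `AttrsDescriptor`, the loop above reads divisibility hints by name and recomputes the equal-to-1 scalars locally from `scalar_args`. A small sketch with placeholder data; the `["tt.divisibility", 16]` encoding mirrors the membership check in this change (it comes from `backend.parse_attr` and may differ across Triton versions):

    arg_names = ["x_ptr", "y_ptr", "n"]
    attrs = {  # as returned alongside the compiled kernel
        "x_ptr": [["tt.divisibility", 16]],
        "y_ptr": [["tt.divisibility", 16]],
        "n": [],
    }
    scalar_args = [(2, "n", 1)]  # (index, name, value) triples
    equal_to_1 = {i for i, _, v in scalar_args if v == 1}

    for i, name in enumerate(arg_names):
      if name.endswith("_ptr"):
        aligned = ["tt.divisibility", 16] in attrs[name]
        print(f"{name}: array parameter, 16-byte aligned = {aligned}")
      elif i not in equal_to_1:
        print(f"{name}: scalar parameter")
      else:
        print(f"{name}: skipped (specialized to 1)")
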
38 changes: 0 additions & 38 deletions tests/triton_call_test.py
@@ -531,44 +531,6 @@ def test_autotune_with_input_output_aliasing(self):
out = add(x, y, kernel=kernel, input_output_aliases={0: 0})
np.testing.assert_allclose(out, expected)

- def test_specialization(self):
-   do_not_specialize = (
-       0,  # a_ptr
-       2,  # M
-       6,  # stride_ak
-       7,  # stride_bk
-       11,  # c_ptr
-   )
-   kernel = triton.jit(do_not_specialize=do_not_specialize)(matmul_kernel.fn)

-   m, n, k = 128, 128, 99
-   x, y = create_random_inputs([m, k], [k, n])

-   with mock.patch.object(code_gen, "ast_to_ttir") as mock_compile:
-     try:
-       _ = matmul(
-           x,
-           y,
-           kernel=kernel,
-           BLOCK_SIZE_M=32,
-           BLOCK_SIZE_N=32,
-           BLOCK_SIZE_K=32,
-           # K_EXACTLY_DIVISIBLE_BY_BLOCK=False,
-       )
-     except TypeError:
-       pass  # Error thrown as the mocked method's return value is invalid.

-   mock_compile.assert_called_once()
-   specialization = mock_compile.call_args[1]['specialization']

-   # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
-   # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
-   # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-   self.assertEqual(specialization.attrs.divisibility_16, [(1,), (3,), (9,)])
-   # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
-   # specialize" leaving `stride_{bn,cn}`.
-   self.assertEqual(specialization.attrs.equal_to_1, [(8,), (10,)])


if __name__ == "__main__":
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
