Skip to content

Commit

Permalink
[Triton] Change xor_sum to use @jit (NFC) (#5769)
Browse files Browse the repository at this point in the history
The interpreter doesn't support the `_generator` special argument at the
moment.
  • Loading branch information
Mogball authored Jan 31, 2025
1 parent d2c8852 commit d102143
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions python/triton/language/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,11 @@ def _xor_combine(a, b):


@core._tensor_member_fn
@core.builtin
@jit
@core._add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None):
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")

input = core._promote_bfloat16_to_float32(input, _builder=_builder)
return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator)
def xor_sum(input, axis=None, keep_dims=False):
core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)


# cumsum
Expand Down

0 comments on commit d102143

Please sign in to comment.