[pallas:triton] Added basic support for lax.concatenate #26533

Merged: 1 commit, Feb 14, 2025
jax/_src/pallas/triton/lowering.py (38 changes: 30 additions, 8 deletions)
@@ -1733,24 +1733,46 @@ def _reshape_lowering_rule(
     return ValueError("`dimensions` is not supported.")
 
   a = _ensure_ir_value(a, *ctx.avals_in)
   [a_aval] = ctx.avals_in
   [out_aval] = ctx.avals_out
-  if not ir.RankedTensorType.isinstance(a.type):
-    assert all(dim_size == 1 for dim_size in out_aval.shape)
-    return _splat(a, out_aval.shape)
+  # Triton Reshape doesn't support scalar result types (only 0d tensors).
+  if out_aval.ndim == 0:
+    return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(a_aval.ndim)))
+  return _reshape(a, out_aval.shape)
 
-  ty = ir.RankedTensorType(a.type)
 
-  # Triton Reshape doesn't support scalar result types (only 0d tensors).
-  if not out_aval.shape:
-    return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(ty.rank)))
+def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value:
+  if not ir.RankedTensorType.isinstance(a.type):
+    assert all(dim_size == 1 for dim_size in shape)
+    return _splat(a, shape)
+
+  ty = ir.RankedTensorType(a.type)
   return tt_dialect.reshape(
-      ir.RankedTensorType.get([*out_aval.shape], ty.element_type, ty.encoding),
+      ir.RankedTensorType.get(shape, ty.element_type, ty.encoding),
       a,
       allow_reorder=False,
   )
 
 
+@register_lowering(lax.concatenate_p)
+def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension):
+  if len(args) != 2:
+    raise NotImplementedError("Only 2-argument concatenate is supported.")
+  x_aval, y_aval = ctx.avals_in
+  x, y = args
+  if dimension != x_aval.ndim-1:
+    raise NotImplementedError(
+        "Only concatenate along the last dimension is supported."
+    )
+  if x_aval.shape[-1] != 1 or y_aval.shape[-1] != 1:
+    raise NotImplementedError(
+        "Only arguments with shape [..., 1] are supported."
+    )
+  return tt_dialect.join(
+      _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1])
+  )
+
+
 def _compute_offsets_from_indices(
     block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
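For reference, a minimal sketch (not part of this PR) of a kernel that exercises the new rule on the Triton backend. The kernel and array names are illustrative; the key point is that both operands have a trailing dimension of 1 and the concatenation runs along the last axis, so the lowering can emit tt.join on the squeezed operands.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def concat_kernel(x_ref, y_ref, o_ref):
  # Both operands are (128, 1); concatenating along the last axis satisfies
  # the rule's two-operand, [..., 1], last-dimension restrictions.
  o_ref[...] = jnp.concatenate([x_ref[...], y_ref[...]], axis=-1)

x = jnp.ones((128, 1), jnp.float32)
y = jnp.zeros((128, 1), jnp.float32)
out = pl.pallas_call(
    concat_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 2), jnp.float32),
)(x, y)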
tests/pallas/ops_test.py (15 changes: 10 additions, 5 deletions)
@@ -753,20 +753,25 @@ def kernel(x_ref, y_ref):
     x = np.arange(1024, dtype=jnp.float32).reshape(8, 128) + 10
     self.assertAllClose(f(x).item(), 10.0)
 
-  @jtu.skip_on_devices("gpu")  # TODO: not implemented
   def test_concat_constant(self):
-    if pltpu is None:
+    if pltpu is None and jtu.test_device_matches(["tpu"]):
       self.skipTest("No TPU module available.")
+    axis = 0
+    num_arrays = 16
+    if jtu.test_device_matches(["gpu"]) and not self.INTERPRET:
+      # Triton only supports concatenation along the last dimension.
+      num_arrays = 2
+      axis = -1
     def kernel(out):
       result = []
-      for i in range(16):
+      for i in range(num_arrays):
         result.append(jnp.full((1, 128), i, jnp.float32))
-      out[:] = jnp.stack(result).reshape(16, 128)
+      out[:] = jnp.stack(result, axis=axis).reshape(num_arrays, 128)
 
     def run(interpret=False):
       return pl.pallas_call(
           kernel,
-          out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
+          out_shape=jax.ShapeDtypeStruct((num_arrays, 128), jnp.float32),
           out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
           interpret=interpret,
       )()
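As a side note, a small shape trace (not from the PR) of why the GPU variant of the test stacks only two arrays along the last axis: jnp.stack(..., axis=-1) expands each (1, 128) operand to (1, 128, 1) and then concatenates along the last dimension, which is exactly the two-operand, [..., 1], last-axis form the new Triton rule accepts.

import jax.numpy as jnp

a = jnp.full((1, 128), 0, jnp.float32)
b = jnp.full((1, 128), 1, jnp.float32)
stacked = jnp.stack([a, b], axis=-1)  # concatenate of (1, 128, 1) operands -> (1, 128, 2)
assert stacked.reshape(2, 128).shape == (2, 128)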