[pallas:triton] Added basic support for lax.concatenate
The corresponding Triton op (`tt.join`) only supports what amounts to
`jnp.stack([x, y], axis=-1)`, so the lowering handles just that case for now.

See #25321.

PiperOrigin-RevId: 726881284
superbobry authored and Google-ML-Automation committed Feb 14, 2025
1 parent 80dcb7b commit 3162cc4
Showing 2 changed files with 40 additions and 13 deletions.
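
As the commit message notes, only the `jnp.stack([x, y], axis=-1)` pattern is lowered for now. A minimal sketch of a Pallas kernel that exercises this pattern on the Triton backend (the kernel name, shapes, and dtype below are illustrative, not taken from the commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def stack_kernel(x_ref, y_ref, o_ref):
  # jnp.stack along the last axis traces to lax.concatenate of two
  # [..., 1]-shaped operands, which is the case this commit lowers to tt.join.
  o_ref[...] = jnp.stack([x_ref[...], y_ref[...]], axis=-1)

x = jnp.ones((8, 128), jnp.float32)
y = jnp.zeros((8, 128), jnp.float32)
out = pl.pallas_call(
    stack_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128, 2), jnp.float32),
)(x, y)
```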
38 changes: 30 additions & 8 deletions jax/_src/pallas/triton/lowering.py
@@ -1733,24 +1733,46 @@ def _reshape_lowering_rule(
     return ValueError("`dimensions` is not supported.")

   a = _ensure_ir_value(a, *ctx.avals_in)
+  [a_aval] = ctx.avals_in
   [out_aval] = ctx.avals_out
-  if not ir.RankedTensorType.isinstance(a.type):
-    assert all(dim_size == 1 for dim_size in out_aval.shape)
-    return _splat(a, out_aval.shape)
+  # Triton Reshape doesn't support scalar result types (only 0d tensors).
+  if out_aval.ndim == 0:
+    return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(a_aval.ndim)))
+  return _reshape(a, out_aval.shape)

-  ty = ir.RankedTensorType(a.type)

-  # Triton Reshape doesn't support scalar result types (only 0d tensors).
-  if not out_aval.shape:
-    return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(ty.rank)))
+def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value:
+  if not ir.RankedTensorType.isinstance(a.type):
+    assert all(dim_size == 1 for dim_size in shape)
+    return _splat(a, shape)

+  ty = ir.RankedTensorType(a.type)
   return tt_dialect.reshape(
-      ir.RankedTensorType.get([*out_aval.shape], ty.element_type, ty.encoding),
+      ir.RankedTensorType.get(shape, ty.element_type, ty.encoding),
       a,
       allow_reorder=False,
   )


+@register_lowering(lax.concatenate_p)
+def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension):
+  if len(args) != 2:
+    raise NotImplementedError("Only 2-argument concatenate is supported.")
+  x_aval, y_aval = ctx.avals_in
+  x, y = args
+  if dimension != x_aval.ndim-1:
+    raise NotImplementedError(
+        "Only concatenate along the last dimension is supported."
+    )
+  if x_aval.shape[-1] != 1 or y_aval.shape[-1] != 1:
+    raise NotImplementedError(
+        "Only arguments with shape [..., 1] are supported."
+    )
+  return tt_dialect.join(
+      _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1])
+  )
+
+
 def _compute_offsets_from_indices(
     block_info: BlockInfo, nd_indexer: NDIndexer
 ) -> ir.Value:
15 changes: 10 additions & 5 deletions tests/pallas/ops_test.py
@@ -753,20 +753,25 @@ def kernel(x_ref, y_ref):
     x = np.arange(1024, dtype=jnp.float32).reshape(8, 128) + 10
     self.assertAllClose(f(x).item(), 10.0)

-  @jtu.skip_on_devices("gpu")  # TODO: not implemented
   def test_concat_constant(self):
-    if pltpu is None:
+    if pltpu is None and jtu.test_device_matches(["tpu"]):
       self.skipTest("No TPU module available.")
+    axis = 0
+    num_arrays = 16
+    if jtu.test_device_matches(["gpu"]) and not self.INTERPRET:
+      # Triton only supports concatenation along the last dimension.
+      num_arrays = 2
+      axis = -1
     def kernel(out):
       result = []
-      for i in range(16):
+      for i in range(num_arrays):
         result.append(jnp.full((1, 128), i, jnp.float32))
-      out[:] = jnp.stack(result).reshape(16, 128)
+      out[:] = jnp.stack(result, axis=axis).reshape(num_arrays, 128)

     def run(interpret=False):
       return pl.pallas_call(
           kernel,
-          out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
+          out_shape=jax.ShapeDtypeStruct((num_arrays, 128), jnp.float32),
           out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
           interpret=interpret,
       )()
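
The reshape targets in the two test branches can be sanity-checked with plain `jnp` outside Pallas (a sketch based only on the lines shown in the diff; the rest of the test is collapsed above):

```python
import jax.numpy as jnp

# TPU / interpreter branch: 16 blocks of shape (1, 128) stacked along axis 0.
tpu_stack = jnp.stack(
    [jnp.full((1, 128), i, jnp.float32) for i in range(16)], axis=0)
assert tpu_stack.shape == (16, 1, 128)
assert tpu_stack.reshape(16, 128).shape == (16, 128)

# GPU (Triton) branch: only two blocks, stacked along the last axis.
gpu_stack = jnp.stack(
    [jnp.full((1, 128), i, jnp.float32) for i in range(2)], axis=-1)
assert gpu_stack.shape == (1, 128, 2)
assert gpu_stack.reshape(2, 128).shape == (2, 128)
```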
