diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD
index e9461a5ceba0..e5b491aef330 100644
--- a/jax/_src/pallas/mosaic_gpu/BUILD
+++ b/jax/_src/pallas/mosaic_gpu/BUILD
@@ -78,6 +78,7 @@ pytype_strict_library(
         "//jax:dtypes",
         "//jax:effects",
         "//jax:mosaic_gpu",
+        "//jax:state_types",
         "//jax:tree_util",
         "//jax/_src/pallas",
         "//jaxlib/mlir:ir",
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 36e6e47cbf4f..42cf59ec6302 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -24,6 +24,7 @@
 import itertools as it
 from typing import Any, ClassVar, Literal
 
+import jax
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import effects
@@ -31,6 +32,7 @@
 from jax._src.pallas import core as pallas_core
 from jax._src.state import indexing
 from jax._src.state import types as state_types
+from jax._src.state import discharge as state_discharge
 import jax.experimental.mosaic.gpu as mgpu
 import jax.numpy as jnp
 from jaxlib.mlir import ir
@@ -114,6 +116,24 @@ def __call__(
     return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)
 
 
+def kernel(body, out_shape, compiler_params=None, **mesh_kwargs):
+  if unwrap_out := not isinstance(out_shape, (tuple, list)):
+    out_shape = (out_shape,)
+  def wrapper(*operands):
+    def stateful(operand_and_out_refs):
+      operand_refs, out_refs = operand_and_out_refs
+      def cmap_body():
+        body(*operand_refs, *out_refs)
+      pallas_core.core_map(
+          GPUMesh(**mesh_kwargs), compiler_params=compiler_params
+      )(cmap_body)
+    _, outs = state_discharge.run_state(stateful)(
+        (operands, jax.tree.map(jnp.zeros_like, out_shape))
+    )
+    return outs[0] if unwrap_out else outs
+  return wrapper
+
+
 @dataclasses.dataclass(frozen=True)
 class GPUMemoryRef(pallas_core.MemoryRef):
   transforms: Sequence[MemoryRefTransform] = ()
@@ -487,7 +507,7 @@ class GPUMesh:
   def __post_init__(self):
     if len(self.axis_names) != len(self.grid) + (self.num_threads is not None):
       raise ValueError("Need as many axis names as grid dimensions + warp groups")
-    if self.num_threads > 2048 // 128:
+    if self.num_threads is not None and self.num_threads > 2048 // 128:
       raise ValueError(
           "Requested too many CUDA threads per block. Each Mosaic thread"
           " corresponds to 128 CUDA threads."
@@ -529,8 +549,6 @@ def _gpu_mesh_discharge_rule(
     raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}")
   if mesh.cluster:
     raise NotImplementedError
-  if mesh.num_threads is None:
-    raise NotImplementedError
   if compiler_params and not isinstance(compiler_params, GPUCompilerParams):
     raise TypeError(
         "Compiler params must be a GPUCompilerParams, got"
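# Usage sketch (not part of the patch): how the `kernel` wrapper added above is
# meant to be called, mirroring the new ExamplesTest cases at the end of this
# diff. `grid`, `axis_names` and `num_threads` are forwarded to `GPUMesh` via
# `**mesh_kwargs`; outputs are zero-initialized buffers shaped like `out_shape`
# and passed to `body` as trailing Refs. The `add_rows` name is illustrative.
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_rows(l_ref, r_ref, o_ref):
  # Each block along the "rows" grid axis handles one 64-row slice.
  rows = pl.ds(lax.axis_index("rows") * 64, 64)
  o_ref[rows] = l_ref[rows] + r_ref[rows]

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(add_rows, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
# out == x + x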
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py
index 32b56a332450..a95a4b74c358 100644
--- a/jax/_src/state/discharge.py
+++ b/jax/_src/state/discharge.py
@@ -546,6 +546,9 @@ def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
     nonlocal_effects.add(
         eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index])
     )
+  assert len(jaxpr.invars) == len(is_initialized)
+  if not all(is_initialized):
+    raise NotImplementedError  # Uninitialized refs are not in avals.
   return avals, nonlocal_effects
 run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)
 
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index 8da2a5095927..b3fb5c2efd92 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -28,6 +28,7 @@
 from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC  # noqa: F401
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
+from jax._src.pallas.mosaic_gpu.core import kernel as kernel
 from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
 from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
 from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 31de258b2a2f..f1c0a1ca53d6 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -190,51 +190,46 @@ def kv_loop(kv_step, _):
         plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot])
       lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None)
 
-  def run(refs):
-    q_ref, k_ref, v_ref, out_ref = refs
-
-    num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
-    if rem:
-      raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
-    mesh = plgpu.GPUMesh(
-        grid=(batch_size, num_q_tiles, num_q_heads),
-        num_threads=3,
-        axis_names=("batch", "q_seq", "heads", "wg"),
+  def entry(q_ref, k_ref, v_ref, out_ref):
+    compute_wgs = 2
+    tiling = plgpu.TilingTransform((64, 64))
+    swizzle = plgpu.SwizzleTransform(128)
+    qo_scratch = plgpu.SMEM(
+        (compute_wgs, block_q, head_dim), jnp.float16,
+        transforms=(tiling, swizzle),
     )
-
-    @pl.core_map(
-        mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
+    k_scratch = plgpu.SMEM(
+        (max_concurrent_steps, block_kv, head_dim), jnp.float16,
+        transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle),
     )
-    def _kernel_entry():
-      compute_wgs = 2
-      tiling = plgpu.TilingTransform((64, 64))
-      swizzle = plgpu.SwizzleTransform(128)
-      qo_scratch = plgpu.SMEM(
-          (compute_wgs, block_q, head_dim), jnp.float16,
-          transforms=(tiling, swizzle),
-      )
-      k_scratch = plgpu.SMEM(
-          (max_concurrent_steps, block_kv, head_dim), jnp.float16,
-          transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle),
-      )
-      v_scratch = plgpu.SMEM(
-          (max_concurrent_steps, block_kv, head_dim), jnp.float16,
-          transforms=(tiling, swizzle),
-      )
-      pl.run_scoped(
-          lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args),
-          (qo_scratch, k_scratch, v_scratch),
-          (
-              plgpu.Barrier(1, num_barriers=max_concurrent_steps),
-              plgpu.Barrier(1, num_barriers=max_concurrent_steps),
-              plgpu.Barrier(1, num_barriers=compute_wgs),
-          ),
-          (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
-          plgpu.Barrier(num_arrivals=compute_wgs),
-      )
+    v_scratch = plgpu.SMEM(
+        (max_concurrent_steps, block_kv, head_dim), jnp.float16,
+        transforms=(tiling, swizzle),
+    )
+    pl.run_scoped(
+        lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args),
+        (qo_scratch, k_scratch, v_scratch),
+        (
+            plgpu.Barrier(1, num_barriers=max_concurrent_steps),
+            plgpu.Barrier(1, num_barriers=max_concurrent_steps),
+            plgpu.Barrier(1, num_barriers=compute_wgs),
+        ),
+        (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
+        plgpu.Barrier(num_arrivals=compute_wgs),
+    )
+
+  num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
+  if rem:
+    raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
 
-  _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf)))
-  return out
+  return plgpu.kernel(
+      entry,
+      out_shape=q,
+      grid=(batch_size, num_q_tiles, num_q_heads),
+      num_threads=3,
+      axis_names=("batch", "q_seq", "heads", "wg"),
+      compiler_params=plgpu.GPUCompilerParams(approx_math=True),
+  )(q, k, v)
 
 
 @jax.jit
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 770e255489e4..cad58b32b0bf 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -23,6 +23,7 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
+from jax import lax
 from jax._src import test_util as jtu
 from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
 from jax.experimental import pallas as pl
@@ -1838,5 +1839,119 @@ def scoped(barrier):
     np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128))
 
 
+class ExamplesTest(PallasTest):
+
+  # Basic
+  def test_stage0(self):
+    def body(l_ref, r_ref, o_ref):
+      o_ref[...] = l_ref[...] + r_ref[...]
+
+    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
+    out = plgpu.kernel(body, out_shape=x)(x, x)
+    np.testing.assert_allclose(out, x + x)
+
+  # Multi-block kernels
+  def test_stage1(self):
+    row_block = 64
+    def body(l_ref, r_ref, o_ref):
+      my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
+      o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice]
+
+    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
+    out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
+    np.testing.assert_allclose(out, x + x)
+
+  # Async copies
+  def test_stage3(self):
+    row_block, col_block = 64, 128
+    def body(l_ref, r_ref, o_ref):
+      my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
+      def scoped(l_smem, r_smem, o_smem, barrier):
+        plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier)
+        plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier)
+        plgpu.barrier_wait(barrier)
+        o_smem[...] = l_smem[...] + r_smem[...]
+        plgpu.commit_smem()
+        plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice])
+        plgpu.wait_smem_to_gmem(0)
+      pl.run_scoped(
+          scoped,
+          *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3),
+          plgpu.Barrier(num_arrivals=2),
+      )
+
+    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
+    out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
+    np.testing.assert_allclose(out, x + x)
+
+  # Pipelining
+  def test_stage4(self):
+    row_block, col_block = 64, 32
+    def body(l_ref, r_ref, o_ref):
+      def compute(l_smem, r_smem, o_smem):
+        o_smem[...] = l_smem[...] + r_smem[...]
+      r = lax.axis_index("rows")
+      block = pl.BlockSpec((row_block, col_block), lambda c: (r, c))
+      plgpu.emit_pipeline(
+          compute,
+          grid=(l_ref.shape[1] // col_block,),
+          in_specs=[block] * 2,
+          out_specs=[block],
+      )(l_ref, r_ref, o_ref)
+
+    x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
+    out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
+    np.testing.assert_allclose(out, x + x)
+
+  # Transforms
+  def test_stage5(self):
+    row_block, col_block = 64, 32
+    def body(l_ref, r_ref, o_ref):
+      def compute(l_smem, r_smem, o_smem):
+        o_smem[...] = l_smem[...] + r_smem[...]
+ r = lax.axis_index("rows") + block = plgpu.GPUBlockSpec( + (row_block, col_block), lambda c: (r, c), + transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)), + ) + plgpu.emit_pipeline( + compute, + grid=(l_ref.shape[1] // col_block,), + in_specs=[block] * 2, + out_specs=[block], + )(l_ref, r_ref, o_ref) + + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) + np.testing.assert_allclose(out, x + x) + + # WGMMA + def test_stage6(self): + m_block = n_block = 64 + k_block = 32 + def body(l_ref, r_ref, o_ref): + def compute(l_smem, r_smem, o_smem): + def do_wgmma(acc_ref): + plgpu.wgmma(acc_ref, l_smem, r_smem) + return acc_ref[...] + o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) + m, n = lax.axis_index("m"), lax.axis_index("n") + lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)) + r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64)) + plgpu.emit_pipeline( + compute, + grid=(l_ref.shape[1] // k_block,), + in_specs=[plgpu.GPUBlockSpec((m_block, k_block), lambda k: (m, k), transforms=lo_transforms), + plgpu.GPUBlockSpec((k_block, n_block), lambda k: (k, n), transforms=r_transforms)], + out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], + )(l_ref, r_ref, o_ref) + + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x) + np.testing.assert_allclose(out, x @ x) + + # TODO(apaszke): Clusters and multicast + + if __name__ == "__main__": absltest.main()