[Pallas:MGPU] Add helpers to make writing core_map kernels less verbose
Also add small "getting started" examples that use the helpers in tests.

PiperOrigin-RevId: 718824070
apaszke authored and Google-ML-Automation committed Jan 24, 2025
1 parent 33ec629 commit ea800a5
Showing 6 changed files with 178 additions and 45 deletions.
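
For orientation, here is a minimal sketch of how the new plgpu.kernel helper is meant to be called, mirroring the test_stage0 example added in tests/pallas/mosaic_gpu_test.py below (the plgpu import alias follows the existing test convention):

import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas import mosaic_gpu as plgpu

def add(l_ref, r_ref, o_ref):
  # The body receives refs for every operand followed by refs for every output.
  o_ref[...] = l_ref[...] + r_ref[...]

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
# out_shape provides the shape/dtype of the output; an array works as a spec.
out = plgpu.kernel(add, out_shape=x)(x, x)
np.testing.assert_allclose(out, x + x)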
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic_gpu/BUILD
@@ -78,6 +78,7 @@ pytype_strict_library(
"//jax:dtypes",
"//jax:effects",
"//jax:mosaic_gpu",
"//jax:state_types",
"//jax:tree_util",
"//jax/_src/pallas",
"//jaxlib/mlir:ir",
24 changes: 21 additions & 3 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -24,13 +24,15 @@
import itertools as it
from typing import Any, ClassVar, Literal

import jax
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
from jax._src import tree_util
from jax._src.pallas import core as pallas_core
from jax._src.state import indexing
from jax._src.state import types as state_types
from jax._src.state import discharge as state_discharge
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
from jaxlib.mlir import ir
@@ -114,6 +116,24 @@ def __call__(
return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)


def kernel(body, out_shape, compiler_params=None, **mesh_kwargs):
if unwrap_out := not isinstance(out_shape, (tuple, list)):
out_shape = (out_shape,)
def wrapper(*operands):
def stateful(operand_and_out_refs):
operand_refs, out_refs = operand_and_out_refs
def cmap_body():
body(*operand_refs, *out_refs)
pallas_core.core_map(
GPUMesh(**mesh_kwargs), compiler_params=compiler_params
)(cmap_body)
_, outs = state_discharge.run_state(stateful)(
(operands, jax.tree.map(jnp.zeros_like, out_shape))
)
return outs[0] if unwrap_out else outs
return wrapper


@dataclasses.dataclass(frozen=True)
class GPUMemoryRef(pallas_core.MemoryRef):
transforms: Sequence[MemoryRefTransform] = ()
@@ -487,7 +507,7 @@ class GPUMesh:
def __post_init__(self):
if len(self.axis_names) != len(self.grid) + (self.num_threads is not None):
raise ValueError("Need as many axis names as grid dimensions + warp groups")
if self.num_threads > 2048 // 128:
if self.num_threads is not None and self.num_threads > 2048 // 128:
raise ValueError(
"Requested too many CUDA threads per block. Each Mosaic thread"
" corresponds to 128 CUDA threads."
@@ -529,8 +549,6 @@ def _gpu_mesh_discharge_rule(
raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}")
if mesh.cluster:
raise NotImplementedError
if mesh.num_threads is None:
raise NotImplementedError
if compiler_params and not isinstance(compiler_params, GPUCompilerParams):
raise TypeError(
"Compiler params must be a GPUCompilerParams, got"
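
As a rough usage note for the helper added above (a sketch based on the test_stage1 example added below, not an exhaustive API description): keyword arguments other than out_shape and compiler_params are forwarded to GPUMesh, so a multi-block launch can be written as:

from jax import lax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

row_block = 64

def add_rows(l_ref, r_ref, o_ref):
  # Each program along the "rows" grid axis handles one block of rows.
  my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
  o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice]

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
# grid and axis_names are passed through to GPUMesh by plgpu.kernel.
out = plgpu.kernel(add_rows, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)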
3 changes: 3 additions & 0 deletions jax/_src/state/discharge.py
@@ -546,6 +546,9 @@ def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr,
nonlocal_effects.add(
eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index])
)
assert len(jaxpr.invars) == len(is_initialized)
if not all(is_initialized):
raise NotImplementedError # Uninitialized refs are not in avals.
return avals, nonlocal_effects
run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval)

1 change: 1 addition & 0 deletions jax/experimental/pallas/mosaic_gpu.py
@@ -28,6 +28,7 @@
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import kernel as kernel
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
79 changes: 37 additions & 42 deletions jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -190,51 +190,46 @@ def kv_loop(kv_step, _):
plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot])
lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None)

def run(refs):
q_ref, k_ref, v_ref, out_ref = refs

num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
mesh = plgpu.GPUMesh(
grid=(batch_size, num_q_tiles, num_q_heads),
num_threads=3,
axis_names=("batch", "q_seq", "heads", "wg"),
def entry(q_ref, k_ref, v_ref, out_ref):
compute_wgs = 2
tiling = plgpu.TilingTransform((64, 64))
swizzle = plgpu.SwizzleTransform(128)
qo_scratch = plgpu.SMEM(
(compute_wgs, block_q, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)

@pl.core_map(
mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True)
k_scratch = plgpu.SMEM(
(max_concurrent_steps, block_kv, head_dim), jnp.float16,
transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle),
)
def _kernel_entry():
compute_wgs = 2
tiling = plgpu.TilingTransform((64, 64))
swizzle = plgpu.SwizzleTransform(128)
qo_scratch = plgpu.SMEM(
(compute_wgs, block_q, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
k_scratch = plgpu.SMEM(
(max_concurrent_steps, block_kv, head_dim), jnp.float16,
transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle),
)
v_scratch = plgpu.SMEM(
(max_concurrent_steps, block_kv, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
pl.run_scoped(
lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args),
(qo_scratch, k_scratch, v_scratch),
(
plgpu.Barrier(1, num_barriers=max_concurrent_steps),
plgpu.Barrier(1, num_barriers=max_concurrent_steps),
plgpu.Barrier(1, num_barriers=compute_wgs),
),
(plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
plgpu.Barrier(num_arrivals=compute_wgs),
)
v_scratch = plgpu.SMEM(
(max_concurrent_steps, block_kv, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
pl.run_scoped(
lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args),
(qo_scratch, k_scratch, v_scratch),
(
plgpu.Barrier(1, num_barriers=max_concurrent_steps),
plgpu.Barrier(1, num_barriers=max_concurrent_steps),
plgpu.Barrier(1, num_barriers=compute_wgs),
),
(plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
plgpu.Barrier(num_arrivals=compute_wgs),
)

num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")

_, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf)))
return out
return plgpu.kernel(
entry,
out_shape=q,
grid=(batch_size, num_q_tiles, num_q_heads),
num_threads=3,
axis_names=("batch", "q_seq", "heads", "wg"),
compiler_params=plgpu.GPUCompilerParams(approx_math=True),
)(q, k, v)


@jax.jit
115 changes: 115 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -23,6 +23,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
from jax.experimental import pallas as pl
@@ -1838,5 +1839,119 @@ def scoped(barrier):
np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128))


class ExamplesTest(PallasTest):

# Basic
def test_stage0(self):
def body(l_ref, r_ref, o_ref):
o_ref[...] = l_ref[...] + r_ref[...]

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(body, out_shape=x)(x, x)
np.testing.assert_allclose(out, x + x)

# Multi-block kernels
def test_stage1(self):
row_block = 64
def body(l_ref, r_ref, o_ref):
my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice]

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
np.testing.assert_allclose(out, x + x)

# Async copies
def test_stage3(self):
row_block, col_block = 64, 128
def body(l_ref, r_ref, o_ref):
my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block)
def scoped(l_smem, r_smem, o_smem, barrier):
plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier)
plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier)
plgpu.barrier_wait(barrier)
o_smem[...] = l_smem[...] + r_smem[...]
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice])
plgpu.wait_smem_to_gmem(0)
pl.run_scoped(
scoped,
*([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3),
plgpu.Barrier(num_arrivals=2),
)

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
np.testing.assert_allclose(out, x + x)

# Pipelining
def test_stage4(self):
row_block, col_block = 64, 32
def body(l_ref, r_ref, o_ref):
def compute(l_smem, r_smem, o_smem):
o_smem[...] = l_smem[...] + r_smem[...]
r = lax.axis_index("rows")
block = pl.BlockSpec((row_block, col_block), lambda c: (r, c))
plgpu.emit_pipeline(
compute,
grid=(l_ref.shape[1] // col_block,),
in_specs=[block] * 2,
out_specs=[block],
)(l_ref, r_ref, o_ref)

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
np.testing.assert_allclose(out, x + x)

# Transforms
def test_stage5(self):
row_block, col_block = 64, 32
def body(l_ref, r_ref, o_ref):
def compute(l_smem, r_smem, o_smem):
o_smem[...] = l_smem[...] + r_smem[...]
r = lax.axis_index("rows")
block = plgpu.GPUBlockSpec(
(row_block, col_block), lambda c: (r, c),
transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)),
)
plgpu.emit_pipeline(
compute,
grid=(l_ref.shape[1] // col_block,),
in_specs=[block] * 2,
out_specs=[block],
)(l_ref, r_ref, o_ref)

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
np.testing.assert_allclose(out, x + x)

# WGMMA
def test_stage6(self):
m_block = n_block = 64
k_block = 32
def body(l_ref, r_ref, o_ref):
def compute(l_smem, r_smem, o_smem):
def do_wgmma(acc_ref):
plgpu.wgmma(acc_ref, l_smem, r_smem)
return acc_ref[...]
o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16))
m, n = lax.axis_index("m"), lax.axis_index("n")
lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64))
r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64))
plgpu.emit_pipeline(
compute,
grid=(l_ref.shape[1] // k_block,),
in_specs=[plgpu.GPUBlockSpec((m_block, k_block), lambda k: (m, k), transforms=lo_transforms),
plgpu.GPUBlockSpec((k_block, n_block), lambda k: (k, n), transforms=r_transforms)],
out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)],
)(l_ref, r_ref, o_ref)

x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128)
out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x)
np.testing.assert_allclose(out, x @ x)

# TODO(apaszke): Clusters and multicast


if __name__ == "__main__":
absltest.main()
