From 54fa1b9aa5a15cb284e8dbc8e0979677138cf01b Mon Sep 17 00:00:00 2001
From: Andrey Portnoy
Date: Wed, 12 Feb 2025 22:53:45 -0500
Subject: [PATCH] [Mosaic GPU] Factor out arch specific Pallas Mosaic GPU tests

---
 jax/_src/test_util.py           |  13 ++
 tests/pallas/mosaic_gpu_test.py | 391 ++++++++++++++++----------
 2 files changed, 210 insertions(+), 194 deletions(-)

diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 1acfdace2107..ce355c213aa8 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -472,6 +472,19 @@ def is_cuda_compute_capability_equal(capability: str) -> bool:
   current = tuple(int(x) for x in d.compute_capability.split("."))
   return current == target
 
+
+class CudaArchSpecificTest:
+  """A mixin with methods for skipping arch-specific tests."""
+
+  def skip_unless_sm90a(self):
+    if not is_cuda_compute_capability_equal("9.0"):
+      self.skipTest("Only works on GPU with capability sm90a")
+
+  def skip_unless_sm100a(self):
+    if not is_cuda_compute_capability_equal("10.0"):
+      self.skipTest("Only works on GPU with capability sm100a")
+
+
 def _get_device_tags():
   """returns a set of tags defined for the device under test"""
   if is_device_rocm():
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 58066ec12b1d..c9b60d796a73 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -72,9 +72,12 @@ def capture_stdout(self):
     # We need to cudaDeviceSynchronize to make sure printfs are flushed.
     mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices()
 
-  def skip_unless_sm90a(self):
-    if not jtu.is_cuda_compute_capability_equal("9.0"):
-      self.skipTest("Only works on GPU with capability sm90a")
+
+class PallasSm90ATest(PallasTest, jtu.CudaArchSpecificTest):
+
+  def setUp(self):
+    self.skip_unless_sm90a()
+    super().setUp()
 
 
 class PallasCallTest(PallasTest):
@@ -900,25 +903,6 @@ def body(step, xs):
         kernel(),
         jnp.full([256], 3 * (0 + 1), jnp.int32)
     )
 
-  @parameterized.parameters(False, True)
-  def test_fori_loop_accumulator(self, force_while):
-    self.skip_unless_sm90a()
-
-    transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
-    @functools.partial(
-        pl.pallas_call,
-        in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)],
-        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16),
-        out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)),
-    )
-    def kernel(i_ref, o_ref):
-      def scope(acc_ref):
-        return _fori_loop(force_while, 0, 4, lambda _, v: v + acc_ref[...], acc_ref[...])
-      o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))
-
-    acc_ini = jnp.ones((64, 64), dtype=jnp.float16)
-    np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16))
-
   @parameterized.parameters(False, True)
   def test_fori_loop_indexed_store(self, force_while):
     @functools.partial(
@@ -1016,9 +1000,193 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(256, dtype=jnp.int32)
     np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256]))
 
+  def test_input_output_aliases(self):
+    # Note that we're writing through the input reference, which should alias
+    # the output buffer (b_ref).
+    def kernel(a_ref, b_ref):
+      del b_ref
+      a_ref[...] = jnp.ones_like(a_ref)
+
+    a = np.zeros((64, 64), dtype=jnp.float32)
+    b = pl.pallas_call(
+        kernel,
+        in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
+        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM),
+        input_output_aliases={0: 0},
+        out_shape=a,
+    )(a)
+    np.testing.assert_array_equal(b, np.ones_like(a))
+
+  def test_slicing(self):
+    left = upper = slice(None, 64)
+    right = lower = slice(64, None)
+    # We rotate the four quadrants of the input clockwise.
+    def rotate(src, dst):
+      dst[upper, left] = src[lower, left]
+      dst[upper, right] = src[upper, left]
+      dst[lower, right] = src[upper, right]
+      dst[lower, left] = src[lower, right]
+
+    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
+    spec = plgpu.GPUBlockSpec(
+        (128, 128),
+        lambda: (0, 0),
+        transforms=(
+            plgpu.TilingTransform((64, 64)),
+            plgpu.SwizzleTransform(128),
+        ),
+    )
+    f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
+    expected = np.empty_like(x)
+    rotate(x, expected)
+    np.testing.assert_array_equal(f(x), expected)
+
+  def test_layout_cast(self, shape=(256, 64)):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
+    )
+    def kernel(o_ref):
+      o_ref[...] = plgpu.layout_cast(jnp.full(shape, 42.0, jnp.float32), plgpu.Layout.WGMMA)
+
+    x = jnp.full(shape, 42.0, jnp.float32)
+    np.testing.assert_array_equal(kernel(), x)
+
+  def test_profiler(self):
+    def kernel(x_ref, o_ref):
+      with jax.named_scope("add"):
+        with jax.named_scope("load"):
+          x = x_ref[...]
+        o = x + x
+        with jax.named_scope("store"):
+          o_ref[...] = o
+    with tempfile.TemporaryDirectory() as tmpdir:
+      x = jnp.arange(256).astype(jnp.float32)
+      y = pl.pallas_call(
+          kernel,
+          out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
+          compiler_params=plgpu.GPUCompilerParams(
+              profile_space=16, profile_dir=tmpdir
+          ),
+      )(x)
+      jax.block_until_ready(y)
+      jax.effects_barrier()
+      [name] = os.listdir(tmpdir)
+      with open(os.path.join(tmpdir, name), "r") as f:
+        data = f.read()
+      self.assertEqual(data.count('"name": "add"'), 2)
+      self.assertEqual(data.count('"name": "load"'), 2)
+      self.assertEqual(data.count('"name": "store"'), 2)
+      np.testing.assert_array_equal(y, x + x)
+
+  @parameterized.parameters(
+      (jnp.float16, jnp.float16),  # Noop
+      (jnp.int16, jnp.bfloat16),
+      (jnp.int16, jnp.float16),
+      (jnp.uint16, jnp.float16),
+      (jnp.float32, jnp.int32),
+      (jnp.float32, jnp.uint32),
+      (jnp.uint32, jnp.int32),
+      (jnp.int32, jnp.uint32),
+  )
+  def test_bitcast_convert_type(self, in_dtype, out_dtype):
+    m, n = 16, 8
+    out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)
+    grid = ()
+
+    @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid)
+    def convert(x_ref, y_ref):
+      y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape)
+
+    x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
+    y = convert(x)
+    y_ref = jax.lax.bitcast_convert_type(x, out_dtype)
+    np.testing.assert_array_equal(y, y_ref)
+
+
+class PallasCallSm90ATest(PallasSm90ATest):
+
+  @parameterized.parameters(False, True)
+  def test_fori_loop_accumulator(self, force_while):
+    transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
+    @functools.partial(
+        pl.pallas_call,
+        in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)],
+        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16),
+        out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)),
+    )
+    def kernel(i_ref, o_ref):
+      def scope(acc_ref):
+        return _fori_loop(force_while, 0, 4, lambda _, v: v + acc_ref[...], acc_ref[...])
+      o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))
+
+    acc_ini = jnp.ones((64, 64), dtype=jnp.float16)
+    np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16))
+
+  def test_realistic_matmul(self):
+    dtype = jnp.float16
+    swizzle = 128
+    elems_128b = swizzle // jnp.dtype(dtype).itemsize
+    grid_m, grid_k, grid_n = 132, 10, 4
+    tile_m = tile_n = 128
+    tile_k = elems_128b
+    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
+    def kernel(a_ref, b_ref, o_ref, acc_ref):
+      # Make sure tiling does not alter the shape of references
+      assert a_ref.shape == (tile_m, tile_k)
+      assert b_ref.shape == (tile_k, tile_n)
+      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
+      plgpu.wgmma(acc_ref, a_ref, b_ref)
+      is_last_step = pl.program_id(2) == grid_k - 1
+      @pl.when(is_last_step)
+      def _epilogue():
+        o_ref[...] = acc_ref[...].astype(dtype)
+      plgpu.wgmma_wait(1)  # We don't await the last WGMMA, hence delay_release=1
+
+    key1, key2 = jax.random.split(jax.random.key(42), 2)
+    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
+    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)
+
+    res = pl.pallas_call(
+        kernel,
+        in_specs=[
+            plgpu.GPUBlockSpec(
+                (tile_m, tile_k),
+                lambda m, n, k: (m, k),
+                transforms=(
+                    plgpu.TilingTransform((64, elems_128b)),
+                    plgpu.SwizzleTransform(128),
+                ),
+            ),
+            plgpu.GPUBlockSpec(
+                (tile_k, tile_n),
+                lambda m, n, k: (k, n),
+                transforms=(
+                    plgpu.TilingTransform((elems_128b, elems_128b)),
+                    plgpu.SwizzleTransform(128),
+                ),
+            ),
+        ],
+        out_specs=plgpu.GPUBlockSpec(
+            (tile_m, tile_n),
+            lambda m, n, k: (m, n),
+            transforms=(
+                plgpu.TilingTransform((64, elems_128b)),
+                plgpu.SwizzleTransform(128),
+            ),
+        ),
+        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
+        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
+        grid=(grid_m, grid_n, grid_k),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "parallel", "sequential"],
+            max_concurrent_steps=2,
+            delay_release=1,
+        ),
+    )(a, b)
+    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
+
   @parameterized.parameters(jnp.float16, jnp.float32)
   def test_wgmma(self, dtype):
-    self.skip_unless_sm90a()
     # TensorCores can only fuse transposes of 16-bit values, and RHS
     # is expected to be column major by default.
     rhs_transpose = jnp.dtype(dtype).itemsize != 2
@@ -1069,7 +1237,6 @@ def scope(acc_ref):
     )
 
   def test_wgmma_registers(self):
-    self.skip_unless_sm90a()
     def kernel(a_ref, b_ref, o_ref):
       def scope(acc_ref):
         plgpu.wgmma(acc_ref, a_ref[...], b_ref)
@@ -1093,7 +1260,6 @@ def scope(acc_ref):
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)
 
   def test_wgmma_registers_init(self):
-    self.skip_unless_sm90a()
    def kernel(a_ref, b_ref, i_ref, o_ref):
       def scope(acc_ref):
         plgpu.wgmma(acc_ref, a_ref[...], b_ref)
@@ -1118,7 +1284,6 @@ def scope(acc_ref):
     np.testing.assert_allclose(res, i + a @ b, rtol=2e-3)
 
   def test_wgmma_sliced_ref(self):
-    self.skip_unless_sm90a()
     def kernel(a_ref, b_ref, o_ref):
       def scope(acc_ref):
         plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0])
@@ -1154,7 +1319,6 @@ def scope(acc_ref):
     np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3)
 
   def test_wgmma_sliced_acc(self):
-    self.skip_unless_sm90a()
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize
     def kernel(a_ref, b_ref, o_ref):
@@ -1193,172 +1357,6 @@ def scope(acc_ref):
     )(a, b)
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)
 
-  def test_input_output_aliases(self):
-    # Note that we're writing to the input pointer, which should alias b_ptr.
-    def kernel(a_ref, b_ref):
-      del b_ref
-      a_ref[...] = jnp.ones_like(a_ref)
-
-    a = np.zeros((64, 64), dtype=jnp.float32)
-    b = pl.pallas_call(
-        kernel,
-        in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)],
-        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM),
-        input_output_aliases={0: 0},
-        out_shape=a,
-    )(a)
-    np.testing.assert_array_equal(b, np.ones_like(a))
-
-  def test_realistic_matmul(self):
-    self.skip_unless_sm90a()
-    dtype = jnp.float16
-    swizzle = 128
-    elems_128b = swizzle // jnp.dtype(dtype).itemsize
-    grid_m, grid_k, grid_n = 132, 10, 4
-    tile_m = tile_n = 128
-    tile_k = elems_128b
-    m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n
-    def kernel(a_ref, b_ref, o_ref, acc_ref):
-      # Make sure tiling does not alter the shape of references
-      assert a_ref.shape == (tile_m, tile_k)
-      assert b_ref.shape == (tile_k, tile_n)
-      assert o_ref.shape == acc_ref.shape == (tile_m, tile_n)
-      plgpu.wgmma(acc_ref, a_ref, b_ref)
-      is_last_step = pl.program_id(2) == grid_k - 1
-      @pl.when(is_last_step)
-      def _epilogue():
-        o_ref[...] = acc_ref[...].astype(dtype)
-      plgpu.wgmma_wait(1)  # We don't await the last WGMMA, hence delay_release=1
-
-    key1, key2 = jax.random.split(jax.random.key(42), 2)
-    a = jax.random.uniform(key1, shape=(m, k), dtype=dtype)
-    b = jax.random.uniform(key2, shape=(k, n), dtype=dtype)
-
-    res = pl.pallas_call(
-        kernel,
-        in_specs=[
-            plgpu.GPUBlockSpec(
-                (tile_m, tile_k),
-                lambda m, n, k: (m, k),
-                transforms=(
-                    plgpu.TilingTransform((64, elems_128b)),
-                    plgpu.SwizzleTransform(128),
-                ),
-            ),
-            plgpu.GPUBlockSpec(
-                (tile_k, tile_n),
-                lambda m, n, k: (k, n),
-                transforms=(
-                    plgpu.TilingTransform((elems_128b, elems_128b)),
-                    plgpu.SwizzleTransform(128),
-                ),
-            ),
-        ],
-        out_specs=plgpu.GPUBlockSpec(
-            (tile_m, tile_n),
-            lambda m, n, k: (m, n),
-            transforms=(
-                plgpu.TilingTransform((64, elems_128b)),
-                plgpu.SwizzleTransform(128),
-            ),
-        ),
-        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
-        scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)],
-        grid=(grid_m, grid_n, grid_k),
-        compiler_params=plgpu.GPUCompilerParams(
-            dimension_semantics=["parallel", "parallel", "sequential"],
-            max_concurrent_steps=2,
-            delay_release=1,
-        ),
-    )(a, b)
-    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
-
-  def test_slicing(self):
-    left = upper = slice(None, 64)
-    right = lower = slice(64, None)
-    # We rotate the four quadrants of the input clockwise.
-    def rotate(src, dst):
-      dst[upper, left] = src[lower, left]
-      dst[upper, right] = src[upper, left]
-      dst[lower, right] = src[upper, right]
-      dst[lower, left] = src[lower, right]
-
-    x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
-    spec = plgpu.GPUBlockSpec(
-        (128, 128),
-        lambda: (0, 0),
-        transforms=(
-            plgpu.TilingTransform((64, 64)),
-            plgpu.SwizzleTransform(128),
-        ),
-    )
-    f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec)
-    expected = np.empty_like(x)
-    rotate(x, expected)
-    np.testing.assert_array_equal(f(x), expected)
-
-  def test_layout_cast(self, shape=(256, 64)):
-    @functools.partial(
-        pl.pallas_call,
-        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
-    )
-    def kernel(o_ref):
-      o_ref[...] = plgpu.layout_cast(jnp.full(shape, 42.0, jnp.float32), plgpu.Layout.WGMMA)
-
-    x = jnp.full(shape, 42.0, jnp.float32)
-    np.testing.assert_array_equal(kernel(), x)
-
-  def test_profiler(self):
-    def kernel(x_ref, o_ref):
-      with jax.named_scope("add"):
-        with jax.named_scope("load"):
-          x = x_ref[...]
-        o = x + x
-        with jax.named_scope("store"):
-          o_ref[...] = o
-    with tempfile.TemporaryDirectory() as tmpdir:
-      x = jnp.arange(256).astype(jnp.float32)
-      y = pl.pallas_call(
-          kernel,
-          out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
-          compiler_params=plgpu.GPUCompilerParams(
-              profile_space=16, profile_dir=tmpdir
-          ),
-      )(x)
-      jax.block_until_ready(y)
-      jax.effects_barrier()
-      [name] = os.listdir(tmpdir)
-      with open(os.path.join(tmpdir, name), "r") as f:
-        data = f.read()
-      self.assertEqual(data.count('"name": "add"'), 2)
-      self.assertEqual(data.count('"name": "load"'), 2)
-      self.assertEqual(data.count('"name": "store"'), 2)
-      np.testing.assert_array_equal(y, x + x)
-
-  @parameterized.parameters(
-      (jnp.float16, jnp.float16),  # Noop
-      (jnp.int16, jnp.bfloat16),
-      (jnp.int16, jnp.float16),
-      (jnp.uint16, jnp.float16),
-      (jnp.float32, jnp.int32),
-      (jnp.float32, jnp.uint32),
-      (jnp.uint32, jnp.int32),
-      (jnp.int32, jnp.uint32),
-  )
-  def test_bitcast_convert_type(self, in_dtype, out_dtype):
-    m, n = 16, 8
-    out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)
-    grid = ()
-
-    @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid)
-    def convert(x_ref, y_ref):
-      y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape)
-
-    x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n))
-    y = convert(x)
-    y_ref = jax.lax.bitcast_convert_type(x, out_dtype)
-    np.testing.assert_array_equal(y, y_ref)
-
 
 class PipelineTest(PallasTest):
@@ -1616,8 +1614,10 @@ def kernel_body(x_smem, o_smem):
     )
     np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
 
+
+class PipelineSm90ATest(PallasSm90ATest):
+
   def test_realistic_matmul(self):
-    self.skip_unless_sm90a()
     dtype = jnp.float16
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(dtype).itemsize
@@ -2046,6 +2046,9 @@ def compute(l_smem, r_smem, o_smem):
     out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
     np.testing.assert_allclose(out, x + x)
 
+
+class ExamplesSm90ATest(PallasSm90ATest):
+
   # WGMMA
   def test_stage6(self):
     m_block = n_block = 64
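
A minimal usage sketch of the pattern this patch introduces, assuming the patch
is applied and the code lives in tests/pallas/mosaic_gpu_test.py (where
PallasTest is defined). The sm100a base class and the test below are
hypothetical: the patch adds skip_unless_sm100a to the mixin but no class that
uses it yet. They simply mirror the PallasSm90ATest pattern above:

    from jax._src import test_util as jtu  # provides CudaArchSpecificTest after this patch

    class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest):
      """Hypothetical base class gating tests on sm100a (compute capability 10.0)."""

      def setUp(self):
        # The mixin's skip helper runs before every test method of every
        # subclass, so no per-test skip calls are needed.
        self.skip_unless_sm100a()
        super().setUp()

    class PallasCallSm100ATest(PallasSm100ATest):

      def test_sm100a_only_feature(self):
        ...  # sm100a-only Pallas kernels would go here.

Putting the arch check in setUp rather than at the top of each test body means
a new arch-specific suite needs only one small base class, and individual
tests can no longer forget to call the skip helper.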