
Commit

[pallas:mgpu] Change FA3 kernel bc lax.div doesn't like mixed types anymore.

PiperOrigin-RevId: 722615539
cperivol authored and Google-ML-Automation committed Feb 7, 2025
1 parent 5bc17f7 commit f2bf47f
Showing 4 changed files with 37 additions and 15 deletions.
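
For context, a minimal standalone sketch of the pattern the kernel switches to. The commit title notes that lax.div no longer accepts operands of mixed dtypes, so dividing the traced int32 head index by a plain Python int has to go through an explicit cast to the dividend's dtype. The function name and the jit wrapper below are illustrative only, not taken from the kernel.

import jax
import jax.numpy as jnp
from jax import lax

q_heads_per_kv_head = 2  # plain Python int, as in the kernel's config

def kv_head_for(q_head):
  # lax.div(q_head, q_heads_per_kv_head) would now complain about the mixed
  # int32 / Python-int operands, so the divisor is materialized as an array
  # of the dividend's dtype first.
  return lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))

kv_head = jax.jit(kv_head_for)(jnp.int32(5))  # -> 2
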
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1884,7 +1884,7 @@ def _ensure_idx_fa(x):
   i32 = ir.IntegerType.get_signless(32)
   if isinstance(x, ir.Value):
     return mgpu.FragmentedArray.splat(
-        x, (), is_signed=mgpu.utils.is_signed(x.type)
+        x, (), is_signed=ir.IntegerType.isinstance(x.type) or None
     ).astype(i32, is_signed=False)
   if isinstance(x, mgpu.FragmentedArray):
     return x.astype(i32, is_signed=False)
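
The lowering change above replaces the mgpu.utils.is_signed helper with a direct MLIR integer-type test. A standalone sketch of what the new expression evaluates to; the import path and the bare Context usage are assumptions made for illustration:

from jax._src.lib.mlir import ir  # assumed import path for the MLIR bindings

with ir.Context():
  i32 = ir.IntegerType.get_signless(32)
  f32 = ir.F32Type.get()
  # Integer-typed values splat with is_signed=True; anything else falls
  # back to None.
  print(ir.IntegerType.isinstance(i32) or None)  # True
  print(ir.IntegerType.isinstance(f32) or None)  # None
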
6 changes: 3 additions & 3 deletions jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -71,6 +71,7 @@ def attention(q, k, v, config: TuningConfig):
 
   def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
+    q_head = lax.axis_index("heads")
     smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
     wg_idx = lax.axis_index("wg")
     qo_smem2, k_smem, v_smem = smem_buffers
@@ -85,7 +86,6 @@ def _compute_wg():
       plgpu.set_max_registers(232, action="increase")
       qo_smem = qo_smem2.at[wg_idx]
       q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
-      q_head = lax.axis_index("heads")
 
       plgpu.copy_gmem_to_smem(
           q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
@@ -175,7 +175,7 @@ def _wait():
     @pl.when(wg_idx == 2)
     def _memory_wg():
       plgpu.set_max_registers(40, action="decrease")
-      kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
+      kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
       for i in range(max_concurrent_steps):
         s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
         plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
@@ -268,11 +268,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
 
   def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
-    kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
     wg_idx = lax.axis_index("wg")
     qo_smem2, q_barriers, schedule_barrier = scoped
     q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
     q_head = lax.axis_index("heads")
+    kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
 
     def perform_schedule_barrier():
       if config.use_schedule_barrier:
8 changes: 6 additions & 2 deletions tests/pallas/mgpu_attention_test.py
@@ -56,16 +56,20 @@ def setUp(self):
           (6, 3),  # GQA
           (4, 4),),  # MHA
       head_dim=(64, 128, 256),
+      attention_impl=(
+          attention_mgpu.attention,
+          attention_mgpu.attention_with_pipeline_emitter,
+      ),
   )
   def test_flash_attention(
-      self, batch_size, q_seq_len, kv_seq_len, num_q_and_kv_heads, head_dim
+      self, batch_size, q_seq_len, kv_seq_len, num_q_and_kv_heads, head_dim, attention_impl
   ):
     num_q_heads, num_kv_heads = num_q_and_kv_heads
     k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
     q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
     k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
     v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
-    out = attention_mgpu.attention(
+    out = attention_impl(
         q, k, v, attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2)
     )
     out_ref = attention_mgpu.attention_reference(q, k, v)
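
Outside the test harness, both entry points now share the same (q, k, v, config) signature, so they can be swapped freely. A hypothetical usage sketch mirroring the new parameterization; the shapes below are made up and a Hopper-class GPU is assumed:

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention_mgpu

k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
# (batch, seq_len, num_heads, head_dim); 6 query heads over 3 KV heads (GQA).
q = jax.random.normal(k1, (1, 4096, 6, 64), jnp.float16)
k = jax.random.normal(k2, (1, 4096, 3, 64), jnp.float16)
v = jax.random.normal(k3, (1, 4096, 3, 64), jnp.float16)
cfg = attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2)

for impl in (attention_mgpu.attention,
             attention_mgpu.attention_with_pipeline_emitter):
  out = impl(q, k, v, cfg)
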
36 changes: 27 additions & 9 deletions tests/pallas/mosaic_gpu_test.py
@@ -338,13 +338,30 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
     np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)
 
   @parameterized.named_parameters(
-      {"testcase_name": "1d_none",
-       "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))},
-      {"testcase_name": "1d_offset",
-       "shape": (256,), "indexers": (slice(32, 96), slice(0, 32))},
-      {"testcase_name": "2d_extract",
-       "shape": (64, 64), "indexers": (4, slice(0, 64))},
-  )
+      {
+          "testcase_name": "1d_none",
+          "shape": (256,),
+          "indexers": (slice(0, 128), slice(None, 32)),
+      },
+      {
+          "testcase_name": "1d_offset",
+          "shape": (256,),
+          "indexers": (slice(32, 96), slice(0, 32)),
+      },
+      {
+          "testcase_name": "2d_extract_static",
+          "shape": (64, 64),
+          "indexers": (4, slice(0, 64)),
+      },
+      {
+          "testcase_name": "2d_extract_dyn",
+          "shape": (64, 64),
+          "indexers": lambda in_dev: (
+              pl.program_id(0) + 4 if in_dev else jnp.array(4),
+              slice(0, 64),
+          ),
+      },
+  )
   def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers):
     @functools.partial(
         pl.pallas_call,
@@ -353,10 +370,11 @@ def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers):
         scratch_shapes=[plgpu.SMEM(shape, jnp.float32),
                         plgpu.Barrier(num_arrivals=1),
                         ],
+        grid=(1,),
     )
     def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
       scratch_ref_sliced = scratch_ref
-      for indexer in indexers:
+      for indexer in indexers(True) if callable(indexers) else indexers:
         scratch_ref_sliced = scratch_ref_sliced.at[indexer]
         x_ref_gmem = x_ref_gmem.at[indexer]
       plgpu.copy_gmem_to_smem(
@@ -368,7 +386,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
     x = jnp.arange(np.prod(shape)).astype(jnp.float32).reshape(*shape)
     result = kernel(x)
     ref = x + 1.0
-    for indexer in indexers:
+    for indexer in indexers(False) if callable(indexers) else indexers:
       result = result[indexer]
       ref = ref[indexer]
     np.testing.assert_array_equal(result, ref)
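
The new 2d_extract_dyn case passes indexers as a callable over an in_dev flag because pl.program_id is only meaningful while tracing the kernel body (which is presumably also why the test now passes grid=(1,)), while the host-side reference slicing uses an ordinary jnp.array(4). A small sketch of the dispatch convention the test adopts; the helper name is hypothetical:

def resolve_indexers(indexers, in_dev):
  # A tuple is used as-is; a callable builds device-side (traced) or
  # host-side (concrete) indexers depending on where it is evaluated.
  return indexers(in_dev) if callable(indexers) else indexers
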
