[Mosaic GPU] Add implementation of FA3 with pipeline emitter.
PiperOrigin-RevId: 716737962
justinjfu authored and Google-ML-Automation committed Jan 24, 2025
1 parent c10b9b8 commit ff5cdb4
Showing 2 changed files with 176 additions and 5 deletions.
1 change: 1 addition & 0 deletions jax/experimental/pallas/mosaic_gpu.py
@@ -30,6 +30,7 @@
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import kernel as kernel
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
180 changes: 175 additions & 5 deletions jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -32,6 +32,7 @@ class TuningConfig:
block_q: int
block_kv: int
max_concurrent_steps: int
use_schedule_barrier: bool = True

def __post_init__(self):
if self.block_q % 64:
@@ -231,6 +232,166 @@ def entry(q_ref, k_ref, v_ref, out_ref):
compiler_params=plgpu.GPUCompilerParams(approx_math=True),
)(q, k, v)

@functools.partial(jax.jit, static_argnames=["config"])
def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}")
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
_, kv_seq_len, num_kv_heads, _ = k.shape
kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim)
if k.shape != kv_shape:
raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)")
if v.shape != kv_shape:
raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}")
if num_q_heads % num_kv_heads:
raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}")
q_heads_per_kv_head = num_q_heads // num_kv_heads
if head_dim % 64:
raise ValueError(f"{head_dim=} must be divisible by 64")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")

max_concurrent_steps = min(
config.max_concurrent_steps, kv_seq_len // config.block_kv
)
compute_wgs = 2
block_q, block_kv = config.block_q, config.block_kv
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")

tiling = plgpu.TilingTransform((64, 64))
swizzle = plgpu.SwizzleTransform(128)
transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4))

def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
batch = lax.axis_index("batch")
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
wg_idx = lax.axis_index("wg")
qo_smem2, q_barriers, schedule_barrier = scoped
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")

def perform_schedule_barrier():
if config.use_schedule_barrier:
plgpu.barrier_arrive(schedule_barrier)
plgpu.barrier_wait(schedule_barrier)
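# The schedule barrier staggers the two compute warpgroups so that while one
# issues WGMMAs the other runs the softmax, keeping the tensor cores busy.
# config.use_schedule_barrier disables this ping-pong scheduling for comparison.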

def _compute_thread():
qo_smem = qo_smem2.at[wg_idx]
m_i = plgpu.layout_cast(
jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW,
)
l_i = plgpu.layout_cast(
jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW,
)
acc = plgpu.layout_cast(
jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
# Q is not pipelined, so we load it in with a manual DMA.
plgpu.copy_gmem_to_smem(
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
qo_smem,
q_barriers.at[wg_idx],
)
plgpu.barrier_wait(q_barriers.at[wg_idx])
pl.when(wg_idx == 1)(perform_schedule_barrier)
final_carry = (yield (acc, m_i, l_i))
del m_i # Unused
pl.when(wg_idx == 0)(perform_schedule_barrier)
acc, _, l_i = final_carry
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
qo_smem[...] = acc.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
)
plgpu.wait_smem_to_gmem(0)

def kv_pipeline(k_smem, v_smem,
k_consumed_barrier, v_consumed_barrier,
carry):
acc, m_i, l_i = carry
qo_smem = qo_smem2.at[wg_idx]
def compute_qk(acc_ref):
plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem, (1, 0)))
perform_schedule_barrier()
return acc_ref[...]
qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
plgpu.barrier_arrive(k_consumed_barrier)

# Softmax
# We keep m scaled by log2e to use FMA instructions when computing p.
log2e = math.log2(math.e)
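# exp(x) == exp2(x * log2e), so with m_ij already scaled by log2e the argument
# of exp2 below is qk * log2e - m_ij, i.e. one fused multiply-add per element.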
m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
alpha = jnp.exp2(m_i - m_ij)
m_i = m_ij
p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
l_i *= alpha
p16 = p.astype(dtype)
perform_schedule_barrier()
l_i += p.sum(axis=1)
# PV
def compute_pv(acc_ref):
plgpu.wgmma(acc_ref, p16, v_smem)
acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
plgpu.barrier_arrive(v_consumed_barrier)
return acc, m_i, l_i
pipeline = plgpu.emit_pipeline_warp_specialized(
kv_pipeline,
grid=(kv_seq_len // block_kv,),
max_concurrent_steps=max_concurrent_steps,
num_compute_wgs=compute_wgs,
memory_registers=40,
wg_axis="wg",
manual_consumed_barriers=True,
carry_coroutine=_compute_thread,
in_specs=[
plgpu.GPUBlockSpec( # k
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, transpose, swizzle]),
plgpu.GPUBlockSpec( # v
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, swizzle]),
],
out_specs=[],
)
k_ref = k_ref.at[batch, :, kv_head, :]
v_ref = v_ref.at[batch, :, kv_head, :]
pipeline(k_ref, v_ref)
mesh = plgpu.GPUMesh(
grid=(batch_size, num_q_tiles, num_q_heads),
num_threads=3,
axis_names=("batch", "q_seq", "heads", "wg"),
)
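# num_threads=3 allocates three warpgroups per block: the two compute
# warpgroups (compute_wgs) plus one memory warpgroup that the warp-specialized
# pipeline uses to produce K/V tiles; "wg" is the axis the emitter uses to
# tell them apart (wg_axis="wg" above).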
def run(refs):
q_ref, k_ref, v_ref, out_ref = refs
@pl.core_map(mesh,
compiler_params=plgpu.GPUCompilerParams(approx_math=True),
)
def _kernel_entry():
qo_scratch = plgpu.SMEM(
(compute_wgs, block_q, head_dim), dtype,
transforms=(tiling, swizzle),
)
pl.run_scoped(
lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args),
qo_scratch,
plgpu.Barrier(1, num_barriers=compute_wgs),
plgpu.Barrier(num_arrivals=compute_wgs),
)
@jax.jit
def run_function(q, k, v, o):
_, _, _, out = pl.run_state(run)((q, k, v, o))
return out
out = run_function(q, k, v, jnp.full_like(q, jnp.inf))
return out


@jax.jit
def attention_reference(q, k, v):
@@ -251,21 +412,30 @@ def attention_reference(q, k, v):
def main(unused_argv):
num_q_heads = 16
num_kv_heads = 16
problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,))
for batch_size, seq_len, head_dim in problem_it:
use_pipeline_emitter = False
if use_pipeline_emitter:
attention_impl = attention_with_pipeline_emitter
schedule_barrier_opts = (True, False)
else:
attention_impl = attention
schedule_barrier_opts = (True,)

problem_it = itertools.product(
(1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts)
for batch_size, seq_len, head_dim, use_schedule_barrier in problem_it:
q_seq_len = kv_seq_len = seq_len
print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
f"{num_q_heads=:<4} {head_dim=:<6} ====")
f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} ====")
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
block_q = 64
best = None
for block_kv in (256, 128, 64):
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier)
try:
out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v)
if seq_len < 32768:
out_ref = attention_reference(q, k, v)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)

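For reference, a minimal sketch of how the new entry point can be invoked directly, mirroring the calls in main(); the shapes are illustrative (they just satisfy the 4D, head_dim % 64 == 0, and q_seq_len % (2 * block_q) == 0 checks) and running it assumes a Hopper-class GPU:

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention_mgpu

q = jax.random.normal(jax.random.key(0), (1, 4096, 16, 128), jnp.float16)
k = jax.random.normal(jax.random.key(1), (1, 4096, 16, 128), jnp.float16)
v = jax.random.normal(jax.random.key(2), (1, 4096, 16, 128), jnp.float16)
config = attention_mgpu.TuningConfig(
    block_q=64, block_kv=128, max_concurrent_steps=2, use_schedule_barrier=True)
out = attention_mgpu.attention_with_pipeline_emitter(q, k, v, config=config)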