[Mosaic GPU] Add implementation of FA3 with pipeline emitter. #26083

Merged: 1 commit, Jan 24, 2025

1 change: 1 addition & 0 deletions jax/experimental/pallas/mosaic_gpu.py
@@ -30,6 +30,7 @@
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import kernel as kernel
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
180 changes: 175 additions & 5 deletions jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -32,6 +32,7 @@ class TuningConfig:
block_q: int
block_kv: int
max_concurrent_steps: int
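# Whether the two compute warpgroups synchronize on a named barrier to interleave their matmul and softmax phases (FA3-style ping-pong scheduling).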
use_schedule_barrier: bool = True

def __post_init__(self):
if self.block_q % 64:
@@ -231,6 +232,166 @@ def entry(q_ref, k_ref, v_ref, out_ref):
compiler_params=plgpu.GPUCompilerParams(approx_math=True),
)(q, k, v)

@functools.partial(jax.jit, static_argnames=["config"])
def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}")
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
_, kv_seq_len, num_kv_heads, _ = k.shape
kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim)
if k.shape != kv_shape:
raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)")
if v.shape != kv_shape:
raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}")
if num_q_heads % num_kv_heads:
raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}")
q_heads_per_kv_head = num_q_heads // num_kv_heads
if head_dim % 64:
raise ValueError(f"{head_dim=} must be divisible by 64")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")

max_concurrent_steps = min(
config.max_concurrent_steps, kv_seq_len // config.block_kv
)
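# Two warpgroups do the compute; the warp-specialized pipeline dedicates a third warpgroup to memory transfers.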
compute_wgs = 2
block_q, block_kv = config.block_q, config.block_kv
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")

tiling = plgpu.TilingTransform((64, 64))
swizzle = plgpu.SwizzleTransform(128)
transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4))
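# SMEM layout transforms for the K/V tiles: (64, 64) tiling plus 128-byte swizzling; the transpose is applied only to K (see in_specs below).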

def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
batch = lax.axis_index("batch")
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
wg_idx = lax.axis_index("wg")
qo_smem2, q_barriers, schedule_barrier = scoped
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")

def perform_schedule_barrier():
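# Staggers the two compute warpgroups so that one issues WGMMAs while the other runs the softmax (FA3's ping-pong schedule).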
if config.use_schedule_barrier:
plgpu.barrier_arrive(schedule_barrier)
plgpu.barrier_wait(schedule_barrier)

def _compute_thread():
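# Carry coroutine for one compute warpgroup: set up the softmax state and load the Q tile, yield to the KV pipeline, then normalize and write out the result.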
qo_smem = qo_smem2.at[wg_idx]
m_i = plgpu.layout_cast(
jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW,
)
l_i = plgpu.layout_cast(
jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW,
)
acc = plgpu.layout_cast(
jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
# Q is not pipelined, so we load it with a manual DMA.
plgpu.copy_gmem_to_smem(
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
qo_smem,
q_barriers.at[wg_idx],
)
plgpu.barrier_wait(q_barriers.at[wg_idx])
pl.when(wg_idx == 1)(perform_schedule_barrier)
final_carry = (yield (acc, m_i, l_i))
del m_i # Unused
pl.when(wg_idx == 0)(perform_schedule_barrier)
acc, _, l_i = final_carry
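# Epilogue: normalize by the softmax denominator and write the output tile to GMEM via SMEM.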
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
qo_smem[...] = acc.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
)
plgpu.wait_smem_to_gmem(0)

def kv_pipeline(k_smem, v_smem,
k_consumed_barrier, v_consumed_barrier,
carry):
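# One pipeline step per KV block: compute S = Q @ K^T, update the online softmax, then accumulate P @ V.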
acc, m_i, l_i = carry
qo_smem = qo_smem2.at[wg_idx]
def compute_qk(acc_ref):
plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem, (1, 0)))
perform_schedule_barrier()
return acc_ref[...]
qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
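# Signal that the K tile has been consumed so the memory warpgroup can reuse its SMEM slot.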
plgpu.barrier_arrive(k_consumed_barrier)

# Softmax
# We keep m scaled by log2e to use FMA instructions when computing p.
log2e = math.log2(math.e)
m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
alpha = jnp.exp2(m_i - m_ij)
m_i = m_ij
p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
l_i *= alpha
p16 = p.astype(dtype)
perform_schedule_barrier()
l_i += p.sum(axis=1)
# PV
def compute_pv(acc_ref):
plgpu.wgmma(acc_ref, p16, v_smem)
acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
plgpu.barrier_arrive(v_consumed_barrier)
return acc, m_i, l_i
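# Set up the warp-specialized pipeline: a memory warpgroup streams K/V blocks into SMEM while the two compute warpgroups run kv_pipeline; the consumed barriers above tell it when slots can be reused.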
pipeline = plgpu.emit_pipeline_warp_specialized(
kv_pipeline,
grid=(kv_seq_len // block_kv,),
max_concurrent_steps=max_concurrent_steps,
num_compute_wgs=compute_wgs,
memory_registers=40,
wg_axis="wg",
manual_consumed_barriers=True,
carry_coroutine=_compute_thread,
in_specs=[
plgpu.GPUBlockSpec( # k
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, transpose, swizzle]),
plgpu.GPUBlockSpec( # v
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, swizzle]),
],
out_specs=[],
)
k_ref = k_ref.at[batch, :, kv_head, :]
v_ref = v_ref.at[batch, :, kv_head, :]
pipeline(k_ref, v_ref)
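# One kernel instance per (batch, q tile, q head); the "wg" axis distinguishes the three warpgroups within each instance.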
mesh = plgpu.GPUMesh(
grid=(batch_size, num_q_tiles, num_q_heads),
num_threads=3,
axis_names=("batch", "q_seq", "heads", "wg"),
)
def run(refs):
q_ref, k_ref, v_ref, out_ref = refs
@pl.core_map(mesh,
compiler_params=plgpu.GPUCompilerParams(approx_math=True),
)
def _kernel_entry():
qo_scratch = plgpu.SMEM(
(compute_wgs, block_q, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
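# Scoped resources: per-warpgroup Q/output SMEM tiles, per-warpgroup barriers for the Q copy, and the shared schedule barrier.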
pl.run_scoped(
lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args),
qo_scratch,
plgpu.Barrier(1, num_barriers=compute_wgs),
plgpu.Barrier(num_arrivals=compute_wgs),
)
@jax.jit
def run_function(q, k, v, o):
_, _, _, out = pl.run_state(run)((q, k, v, o))
return out
out = run_function(q, k, v, jnp.full_like(q, jnp.inf))
return out


@jax.jit
def attention_reference(q, k, v):
@@ -251,21 +412,30 @@ def attention_reference(q, k, v):
def main(unused_argv):
num_q_heads = 16
num_kv_heads = 16
problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,))
for batch_size, seq_len, head_dim in problem_it:
use_pipeline_emitter = False
if use_pipeline_emitter:
attention_impl = attention_with_pipeline_emitter
schedule_barrier_opts = (True, False)
else:
attention_impl = attention
schedule_barrier_opts = (True,)

problem_it = itertools.product(
(1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts)
for batch_size, seq_len, head_dim, use_schedule_barrier in problem_it:
q_seq_len = kv_seq_len = seq_len
print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
f"{num_q_heads=:<4} {head_dim=:<6} ====")
f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} ====")
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
block_q = 64
best = None
for block_kv in (256, 128, 64):
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier)
try:
out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v)
if seq_len < 32768:
out_ref = attention_reference(q, k, v)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
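
For reference, a minimal sketch of calling the new entry point directly (the shapes and tuning values here are illustrative, not taken from this PR; the import path follows the file above):

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention_mgpu import TuningConfig
from jax.experimental.pallas.ops.gpu.attention_mgpu import attention_with_pipeline_emitter

# Requires a Hopper-class GPU, since the kernel is built on WGMMA.
q = jax.random.normal(jax.random.key(0), (1, 4096, 16, 64), jnp.float16)
k = jax.random.normal(jax.random.key(1), (1, 4096, 16, 64), jnp.float16)
v = jax.random.normal(jax.random.key(2), (1, 4096, 16, 64), jnp.float16)
config = TuningConfig(block_q=64, block_kv=128, max_concurrent_steps=2, use_schedule_barrier=True)
out = attention_with_pipeline_emitter(q, k, v, config)  # same shape and dtype as q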