[Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 717952238
apaszke authored and Google-ML-Automation committed Jan 24, 2025
1 parent 313e35a commit 4ff59de
Showing 4 changed files with 106 additions and 16 deletions.
jax/experimental/mosaic/gpu/launch_context.py (2 changes: 1 addition & 1 deletion)
@@ -290,7 +290,7 @@ def init_tma_desc(host_ptr):
       args = [
           host_ptr,
           base_ptr,
-          c(utils.bytewidth(ref_ty.element_type), i64),
+          c(utils.bitwidth(ref_ty.element_type), i64),
           c(rank, i64),
           utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
           utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
jax/experimental/mosaic/gpu/utils.py (23 changes: 17 additions & 6 deletions)
@@ -1067,13 +1067,24 @@ def memref_ptr(memref_arg, memory_space=None):
   desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
   aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
 
-  elem_bytewidth = bytewidth(memref_ty.element_type)
   offset_elems = llvm.extractvalue(i64, desc, [2])
-  offset_bytes = llvm.mul(
-      offset_elems,
-      c(elem_bytewidth, i64),
-      overflow_flags=llvm.IntegerOverflowFlags.none,
-  )
+  elem_bitwidth = bitwidth(memref_ty.element_type)
+  if elem_bitwidth < 8:
+    *_, static_offset = memref_ty.get_strides_and_offset()
+    if static_offset == ir.ShapedType.get_dynamic_stride_or_offset():
+      raise NotImplementedError
+    assert elem_bitwidth.bit_count() == 1
+    packing = 8 // elem_bitwidth
+    if static_offset % packing != 0:
+      raise ValueError
+    offset_bytes = c(static_offset // packing, i64)
+  else:
+    assert elem_bitwidth % 8 == 0
+    offset_bytes = llvm.mul(
+        offset_elems,
+        c(elem_bitwidth // 8, i64),
+        overflow_flags=llvm.IntegerOverflowFlags.none,
+    )
   return llvm.inttoptr(
       ptr_ty,
       llvm.add(
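For sub-byte element types the byte offset can no longer be computed as offset_elems * element_bytewidth, so the new branch divides the static element offset by the packing factor instead. A minimal sketch of that arithmetic with plain Python integers (illustrative helper name, not part of the module):

def sub_byte_offset_bytes(static_offset: int, elem_bitwidth: int) -> int:
  # Elements per byte, e.g. 2 for 4-bit types; only power-of-two widths below 8.
  assert elem_bitwidth < 8 and elem_bitwidth.bit_count() == 1
  packing = 8 // elem_bitwidth
  if static_offset % packing != 0:
    # An offset that starts in the middle of a byte has no byte address.
    raise ValueError("offset must be a multiple of the packing factor")
  return static_offset // packing

# For int4 (packing = 2), an offset of 6 elements starts 3 bytes into the buffer.
assert sub_byte_offset_bytes(6, 4) == 3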
jaxlib/mosaic/gpu/runtime.cc (25 changes: 24 additions & 1 deletion)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cassert>
 #include <cstdint>
 #include <cstdio>
 
@@ -21,7 +22,7 @@ limitations under the License.
 extern "C" {
 
 void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
-                              int64_t elem_bytewidth, int64_t rank,
+                              int64_t elem_bitwidth, int64_t rank,
                               int64_t *sizes, int64_t *strides,
                               int64_t swizzle_bytes, int64_t *window_shape) {
   if (((uintptr_t)tma_desc) % 64 != 0) {
@@ -31,6 +32,28 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
     abort();
   }
 
+  // Pack 4 bit types in 8 bit pairs.
+  int64_t elem_bytewidth;
+  if (elem_bitwidth < 8) {
+    // Check that it's a power of 2.
+    assert((elem_bitwidth & (elem_bitwidth - 1)) == 0);
+    int packing = 8 / elem_bitwidth;
+    assert(sizes[rank - 1] % packing == 0);
+    assert(window_shape[rank - 1] % packing == 0);
+    assert(strides[rank - 1] == 1);
+
+    // TMA requires that the last dimension be the contiguous one so we pack the
+    // elements under that assumption.
+    sizes[rank - 1] /= packing;
+    window_shape[rank - 1] /= packing;
+    for (int i = 0; i < rank - 1; i++) {
+      strides[i] /= packing;
+    }
+    elem_bytewidth = 1;
+  } else {
+    elem_bytewidth = elem_bitwidth / 8;
+  }
+
   CUtensorMapDataType data_type;
   if (elem_bytewidth == 1) {
     data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
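Concretely, the runtime rescales the array metadata so that the TMA descriptor sees a byte array along the contiguous minor dimension. A rough Python sketch of that rescaling (the runtime does it in C++ just before building the CUtensorMap; the helper name and list-based signature are illustrative):

def pack_sub_byte_dims(sizes, strides, window_shape, elem_bitwidth):
  # Only power-of-two sub-byte widths, with a contiguous last dimension.
  assert elem_bitwidth < 8 and (elem_bitwidth & (elem_bitwidth - 1)) == 0
  assert strides[-1] == 1
  packing = 8 // elem_bitwidth
  assert sizes[-1] % packing == 0 and window_shape[-1] % packing == 0
  sizes = sizes[:-1] + [sizes[-1] // packing]
  window_shape = window_shape[:-1] + [window_shape[-1] // packing]
  strides = [s // packing for s in strides[:-1]] + [1]
  return sizes, strides, window_shape  # the descriptor data type becomes UINT8

# A (128, 64) int4 array: each row of 64 nibbles becomes 32 bytes.
assert pack_sub_byte_dims([128, 64], [64, 1], [128, 64], 4) == (
    [128, 32], [32, 1], [128, 32])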
tests/mosaic/gpu_test.py (72 changes: 64 additions & 8 deletions)
@@ -88,20 +88,73 @@ def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
     thread_id = arith.addi(thread_id, arith.muli(gpu.thread_id(dim), stride))
     stride = arith.muli(stride, gpu.block_dim(dim))
   is_first_thread = arith.cmpi(arith.CmpIPredicate.eq, thread_id, c(0, index))
+
   src_ty = ir.MemRefType(src.type)
   dst_ty = ir.MemRefType(dst.type)
   if src_ty.shape != dst_ty.shape:
     raise ValueError(
         f"src and dst shapes don't match: {src_ty.shape} != {dst_ty.shape}"
     )
   shape = src_ty.shape
-  dyn_strides = [c(s, index) for s in get_contiguous_strides(shape)]
+  if src_ty.element_type != dst_ty.element_type:
+    raise ValueError(
+        f"src and dst element types don't match: {src_ty.element_type} !="
+        f" {dst_ty.element_type}"
+    )
+  contig_strides = get_contiguous_strides(shape)
+  # If swizzling is on, at least one of the memrefs must be contiguous
+  # (simulating a TMA).
+  if (swizzle is not None and
+      src_ty.get_strides_and_offset()[0] != contig_strides and
+      dst_ty.get_strides_and_offset()[0] != contig_strides):
+    raise NotImplementedError(src_ty, dst_ty)
+
+  bw = bitwidth(src_ty.element_type)
+  if bw < 8:
+    assert bw.bit_count() == 1
+    packing = 8 // bw
+    if shape[-1] % packing:
+      raise NotImplementedError
+    workgroup_mem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+    shape = (*shape[:-1], shape[-1] // packing)
+    contig_strides = get_contiguous_strides(shape)
+    def bitcast(ref):
+      ref_ty = ir.MemRefType(ref.type)
+      old_strides = ref_ty.get_strides_and_offset()[0]
+      if old_strides[-1] != 1:
+        raise NotImplementedError
+      new_strides = [s // packing for s in old_strides[:-1]] + [1]
+      new_ref_ty = ir.MemRefType.get(
+          shape,
+          ir.VectorType.get((packing,), src_ty.element_type),  # noqa: F821
+          ir.StridedLayoutAttr.get(0, new_strides),
+          ref_ty.memory_space,
+      )
+      ptr_space = (
+          3
+          if ref_ty.memory_space is not None
+          and ref_ty.memory_space == workgroup_mem
+          else None
+      )
+      return ptr_as_memref(
+          # NOTE: memref_ptr applies the offset in case there was any.
+          memref_ptr(ref, memory_space=ptr_space),
+          new_ref_ty,
+          ptr_memory_space=ptr_space,
+      )
+    src = bitcast(src)
+    dst = bitcast(dst)
+    bw = 8
+  del src_ty, dst_ty  # If you remove this, update it in the branch above
+  dyn_strides = [c(s, index) for s in contig_strides]
+
   with ir.InsertionPoint(scf.IfOp(is_first_thread).then_block):
     def body(*idx):
       dst_idx = idx
       if swizzle is not None:
         assert swizzle.bit_count() == 1
-        bytes_per_element = bytewidth(src_ty.element_type)
+        assert bw % 8 == 0
+        bytes_per_element = bw // 8
         linear_idx = c(0, index)
         for stride, i in zip(dyn_strides, idx):
           linear_idx = arith.addi(linear_idx, arith.muli(i, stride))
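The test helper cannot address individual sub-byte elements, so it bitcasts both refs to refs of byte-sized vectors (vector<2 x i4> for int4) and copies those instead. A small numpy analogue of what that packed view means for the underlying bytes (illustrative only; the real code works on MLIR memrefs):

import numpy as np

# Eight 4-bit values, two packed per byte (low nibble first, as an illustration).
nibbles = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.uint8)
packed = nibbles[0::2] | (nibbles[1::2] << 4)  # shape (4,), one byte per pair
assert packed.shape == (nibbles.shape[0] // 2,)
# Copying `packed` moves two logical int4 elements per byte, which is what the
# vector<2 x i4> view lets the copy above do with ordinary memref indexing.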
@@ -963,10 +1016,11 @@ class TMATest(TestCase):
   @parameterized.product(
       swizzle=(None, 32, 64, 128),
       shape=((64, None), (5, None), (2, 3, 5, None)),
-      dtype=(jnp.float16, jnp.float32),
+      dtype=(jnp.float32, jnp.float16, jnp.int4),
   )
   def test_tma_load_basic(self, swizzle, shape, dtype):
-    minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize
+    bw = bitwidth(dtype_to_ir_type(dtype))
+    minor_size = 64 if swizzle is None else 8 * swizzle // bw
     shape = (*shape[:-1], minor_size)
     i1 = ir.IntegerType.get_signless(1)
     def kernel(ctx, src, dst, smem):
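The minor dimension is now derived from the element bit width so that one swizzled row still spans exactly swizzle bytes; the previous formula used jnp.dtype(dtype).itemsize, which is measured in whole bytes and cannot describe 4-bit elements. A quick check of the new formula with illustrative values:

# minor_size = 8 * swizzle // bw: a 128-byte swizzled row holds
# 32 float32 (bw=32), 64 float16 (bw=16), or 256 int4 (bw=4) elements.
for bw, expected in [(32, 32), (16, 64), (4, 256)]:
  assert 8 * 128 // bw == expected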
@@ -1044,12 +1098,14 @@ def kernel(ctx, src, dst, scratch):
            idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))
        )
        stride *= cluster[d]
-      slc = ds(
-          arith.muli(idx, c(16, index)), 16
+      idx_minor = arith.divui(idx, c(2, index))
+      idx_major = arith.remui(idx, c(2, index))
+      slc_minor = ds(
+          arith.muli(idx_minor, c(16 * 2, index)), 16 * 2
       )
       copy(
-          memref_slice(tmp, (slice(None), slc)),
-          memref_slice(dst, (noncollective_idx, slice(None), slc)),
+          memref_slice(tmp, (idx_major, slc_minor)),
+          memref_slice(dst, (noncollective_idx, idx_major, slc_minor)),
           swizzle=swizzle,
       )
     x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
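The collective copy now splits the linear block index idx into a major part (idx % 2) and a minor part (idx // 2), each minor index selecting a 32-element slice. A plain-Python illustration of that index arithmetic (values only, no MLIR):

# Map a linear collective block index to (major index, minor column range).
for idx in range(4):
  idx_major, idx_minor = idx % 2, idx // 2
  lo = idx_minor * 16 * 2
  print(idx, "->", idx_major, (lo, lo + 32))  # 0 -> 0, (0, 32); 1 -> 1, (0, 32); ...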
