From 4693b1cc9e102dd18a093ffff8275b2c2f4da389 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Fri, 31 Jan 2025 16:20:54 -0500 Subject: [PATCH] [AMDGPU] Use shared memory in multi_mma ukernel (#19786) This achieves about 210 Top/s on CPX-mode MI300X, about 64% of peak 327 Top/s. That's about parity with the non-ukernel codegen path, which also uses shared memory. An earlier revision of this PR was opting out of DistributeMmaToLanes, which was more natural since a kernel that uses shared memory has to perform workgroup-relative indexing in the copies from global to shared memory. That required fine ordering of the pass pipeline, and ended up performing worse, at 180 Top/s vs 210 Top/s. So this PR instead stays on DistributeMmaToLanes, and then adds the negative thread-relative offsets to compensate. This relies on interpreting bitcode to tell exactly how much shared memory to allocate. That takes 2 ms. To avoid doing it redundantly, this is cached, with the `DataTiledMMAAttr` value as key, so this should only run a few times per iree-compile invocation. When it is determined that no shared memory should be allocated, to avoid creating 0-sized tensors, a new `iree_codegen.null_pointer` type is introduced to be passed in lieu of an actual tensor. It lowers to a null pointer (and offset). It is intended to be used with ukernels taking a tensor/memref/pointer argument that is nullable, such as the shared memory argument here. --------- Signed-off-by: Benoit Jacob --- .../target/ROCM/builtins/ukernel/common.h | 21 +- ...uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c | 170 +++++++++-- .../test/config_ukernel_multi_mma_gfx942.mlir | 1 + .../Codegen/Common/GPU/GPULowerToUKernels.cpp | 35 ++- .../compiler/Codegen/Common/GPU/Passes.td | 2 + .../GPU/test/gpu_lower_to_ukernels.mlir | 30 +- .../Common/test/lower_ukernel_to_calls.mlir | 13 + .../Codegen/Dialect/Codegen/IR/BUILD.bazel | 27 ++ .../Codegen/Dialect/Codegen/IR/CMakeLists.txt | 15 + .../Dialect/Codegen/IR/IREECodegenDialect.cpp | 8 +- .../Dialect/Codegen/IR/IREECodegenDialect.td | 1 + .../Dialect/Codegen/IR/IREECodegenOps.h | 1 + .../Dialect/Codegen/IR/IREECodegenOps.td | 11 +- .../Dialect/Codegen/IR/IREECodegenTypes.cpp | 15 + .../Dialect/Codegen/IR/IREECodegenTypes.h | 5 + .../Dialect/Codegen/IR/IREECodegenTypes.td | 21 ++ .../Codegen/Dialect/Codegen/IR/UKernelOps.cpp | 10 + .../Codegen/Dialect/Codegen/IR/UKernelOps.h | 1 + .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 5 +- .../Codegen/LLVMGPU/ConvertToLLVM.cpp | 23 +- .../Codegen/LLVMGPU/Utils/BUILD.bazel | 3 + .../Codegen/LLVMGPU/Utils/CMakeLists.txt | 3 + .../LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp | 275 +++++++++++++++--- .../LLVMGPU/test/convert_to_rocdl.mlir | 13 + 24 files changed, 608 insertions(+), 101 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.td diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/common.h b/compiler/plugins/target/ROCM/builtins/ukernel/common.h index 3113643ca1d1..0d685eba54d7 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/common.h +++ b/compiler/plugins/target/ROCM/builtins/ukernel/common.h @@ -81,29 +81,14 @@ _Float16 __ockl_wfred_max_f16(_Float16); int64_t __ockl_wfred_min_i64(int64_t); int32_t __ockl_wfred_min_i32(int32_t); -#define __CLK_LOCAL_MEM_FENCE 0x01 -typedef unsigned __cl_mem_fence_flags; - static inline void __threadfence_block() { __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, 
"workgroup"); } -static inline void __work_group_barrier(__cl_mem_fence_flags flags) { - if (flags) { - __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup"); - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup"); - } else { - __builtin_amdgcn_s_barrier(); - } -} - -static inline void __barrier(int n) { - __work_group_barrier((__cl_mem_fence_flags)n); -} - [[clang::convergent]] static inline void __syncthreads() { - __barrier(__CLK_LOCAL_MEM_FENCE); + __builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup"); } //===----------------------------------------------------------------------===// diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c index c329270ff3e0..30a5e94250aa 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c @@ -6,46 +6,176 @@ #include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" -// Very naive kernel. TODO(bjacob): -// 1. Shared memory: can't allocate it within the microkernel (which is just a -// helper device function, not the actual amdgpu_kernel). Need to get it -// passed down here as additional parameters. -// 2. Better scheduling via either barrier intrinsics or inline assemby. -[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8( - const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer, - int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size, +// Encodes some information about a A/B fragment tile; +typedef struct ab_tile_info_t { + // Terminology: "_vecs": + // We will be counting in units of "vectors", meaning, for each A/B fragment + // the corresponding operand type of this particular MFMA intrinsic. + // For A and B, that type is i64, used as <8 x i8>. + + // Number of vectors in the tile. + int num_vecs; + // Number of vectors that we store in shared memory. That is typically equal + // to num_vecs if using shared memory for the tile, or 0 otherwise. + int num_shared_vecs; +} ab_tile_info_t; + +static ab_tile_info_t get_ab_tile_info(int tile_intrinsics, int tile_subgroups, + int opposite_tile_subgroups) { + ab_tile_info_t info; + info.num_vecs = /*subgroup size*/ 64 * tile_intrinsics * tile_subgroups; + // Use shared memory if the opposite tile has more than 1 subgroup, so that + // using shared memory would amortize loads from global memory. + info.num_shared_vecs = opposite_tile_subgroups > 1 ? info.num_vecs : 0; + return info; +} + +static int32_t get_shared_memory_bytes(ab_tile_info_t a_tile, + ab_tile_info_t b_tile) { + // For this MFMA intrinsic, the A and B vector types are 8 bytes. + return 8 * (a_tile.num_shared_vecs + b_tile.num_shared_vecs); +} + +// The bitcode of this function is interpreted during IREE compilation to +// determine the exact shared_memory_bytes to pass to the ukernel. +int32_t iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_query_shared_memory_bytes( int32_t intrinsics_m, int32_t subgroups_m, int32_t intrinsics_n, int32_t subgroups_n, int32_t intrinsics_k) { - // Load existing accumulators. The VLA becomes a normal array after inlining. 
- int32x4_t c[intrinsics_m][intrinsics_n]; - int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset); + ab_tile_info_t a_tile = + get_ab_tile_info(intrinsics_m * intrinsics_k, subgroups_m, subgroups_n); + ab_tile_info_t b_tile = + get_ab_tile_info(intrinsics_n * intrinsics_k, subgroups_n, subgroups_m); + return get_shared_memory_bytes(a_tile, b_tile); +} + +// Microkernel for iree_gpu.multi_mma with DataTiledMMAAttr with +// intrinsic = MFMA_I32_16x16x32_I8 and a shape with outer M and N dimensions +// equal to 1 (so that this is just doing the inner loop on the K dimension). +// +// This microkernel uses a shared memory workspace buffer provided by the +// caller. It is used to copy tiles of the A and/or B matrices, depending on +// which ones are reused by multiple subgroups. +// +// Note that the A, B, C matrix pointers are all after thread-distribution. +// When the pointer before thread-distribution is needed (when copying data +// into shared memory), care is taken to subtract the thread-relative offset, +// which is computed from the thread id. +// +// As this function is always_inline, some of its parameters are actually +// constant values after inlining, so some for() loops and if() branches here +// are actually unrolled/resolved at compile time, making this microkernel +// a generic "template". This is summarized in the below table. +// +// Parameters | Constant? | Description +// --------------------------- | ---------- | ----------- +// a_base, a_offset | No | A-matrix pointer (thread-distrib.) +// b_base, b_offset | No | B-matrix pointer (thread-distrib.) +// c_base, c_offset | No | C-matrix pointer (thread-distrib.) +// shared_memory_{base,offset} | No | Shared memory workspace pointer +// shared_memory_bytes | Yes | Shared memory workspace size +// k_size | From shape | Size of outer K dimension +// intrinsics_m, subgroups_m | Yes | See DataTiledMMAAttr +// intrinsics_n, subgroups_n | Yes | See DataTiledMMAAttr +// intrinsics_k | Yes | See DataTiledMMAAttr +// +// TODO(bjacob): Better scheduling via either barrier intrinsics or inline asm. +[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8( + const int8_t *a_base, int64_t a_offset, const int8_t *b_base, + int64_t b_offset, int32_t *c_base, int64_t c_offset, + int8_t *shared_memory_base, int64_t shared_memory_offset, + int32_t shared_memory_bytes, int32_t k_size, int32_t intrinsics_m, + int32_t subgroups_m, int32_t intrinsics_n, int32_t subgroups_n, + int32_t intrinsics_k) { + ab_tile_info_t a_tile = + get_ab_tile_info(intrinsics_m * intrinsics_k, subgroups_m, subgroups_n); + ab_tile_info_t b_tile = + get_ab_tile_info(intrinsics_n * intrinsics_k, subgroups_n, subgroups_m); + + // shared_memory_bytes should match exactly, as the value ultimately comes + // from the ..._query_shared_memory_bytes function defined just above. + if (shared_memory_bytes != get_shared_memory_bytes(a_tile, b_tile)) { + __builtin_trap(); + } + // Set up our pointers to shared memory for A and B tiles. + int64_t *restrict a_shared = + (int64_t *)(shared_memory_base + shared_memory_offset); + int64_t *restrict b_shared = a_shared + a_tile.num_shared_vecs; + + // Determine our thread id and the range for it. + int tid = __builtin_amdgcn_workitem_id_x(); + int numthreads = 64 * subgroups_m * subgroups_n; + __builtin_assume(tid < numthreads); + + // Compute the thread-relative data offsets. 
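+  // How these offsets are derived: tid = 64 * subgroup_id + lane_id, with
+  // subgroups laid out row-major over (m, n). Within a tile, each lane owns
+  // intrinsics_k consecutive vecs, and each m (resp. n) subgroup owns a
+  // block of 64 * intrinsics_m (resp. intrinsics_n) * intrinsics_k vecs,
+  // which is where the two formulas below come from.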
+ int lane_id = tid % 64; + int subgroup_id = tid / 64; + int subgroup_n_idx = subgroup_id % subgroups_n; + int subgroup_m_idx = subgroup_id / subgroups_n; + int a_thread_relative_offset = + intrinsics_k * (lane_id + 64 * intrinsics_m * subgroup_m_idx); + int b_thread_relative_offset = + intrinsics_k * (lane_id + 64 * intrinsics_n * subgroup_n_idx); + + // Set up pointers to global memory. + const int64_t *restrict a_global = (const int64_t *)(a_base + a_offset); + const int64_t *restrict b_global = (const int64_t *)(b_base + b_offset); + int32x4_t *restrict c_global = ((int32x4_t *)(c_base + c_offset)); + + // Load existing accumulators from global memory into registers. + // The VLA becomes a normal array after inlining. + int32x4_t c_regs[intrinsics_m][intrinsics_n]; for (int m = 0; m < intrinsics_m; ++m) { for (int n = 0; n < intrinsics_n; ++n) { - c[m][n] = c_global[64 * (m * intrinsics_n + n)]; + c_regs[m][n] = c_global[64 * (m * intrinsics_n + n)]; } } // Arithmetic loop. - const int64_t *a_global = (const int64_t *)(a_buffer + a_offset); - const int64_t *b_global = (const int64_t *)(b_buffer + b_offset); for (int k_outer = 0; k_outer < k_size; ++k_outer) { + // Pointers to A/B data to feed MFMA, based on whether shared memory is + // used. + const int64_t *restrict a_mfma_vecs = + a_tile.num_shared_vecs ? a_shared + a_thread_relative_offset : a_global; + const int64_t *restrict b_mfma_vecs = + b_tile.num_shared_vecs ? b_shared + b_thread_relative_offset : b_global; + + // If needed, load data from global to shared memory. + if (tid < a_tile.num_shared_vecs) { // Benefits from above __builtin_assume. + for (int i = 0; i < a_tile.num_shared_vecs; i += numthreads) { + a_shared[i + tid] = a_global[i + tid - a_thread_relative_offset]; + } + } + if (tid < b_tile.num_shared_vecs) { // Benefits from above __builtin_assume. + for (int i = 0; i < b_tile.num_shared_vecs; i += numthreads) { + b_shared[i + tid] = b_global[i + tid - b_thread_relative_offset]; + } + } + // Thread barrier if any shared memory is used. + if (a_tile.num_shared_vecs || b_tile.num_shared_vecs) { + __syncthreads(); + } + // Load data from shared memory and perform arithmetic. for (int m = 0; m < intrinsics_m; ++m) { for (int n = 0; n < intrinsics_n; ++n) { for (int k = 0; k < intrinsics_k; ++k) { - c[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8( - a_global[64 * intrinsics_k * m + k], - b_global[64 * intrinsics_k * n + k], c[m][n], 0, 0, 0); + c_regs[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8( + a_mfma_vecs[64 * intrinsics_k * m + k], + b_mfma_vecs[64 * intrinsics_k * n + k], c_regs[m][n], 0, 0, 0); } } } - a_global += 64 * intrinsics_m * subgroups_m * intrinsics_k; - b_global += 64 * intrinsics_n * subgroups_n * intrinsics_k; + a_global += a_tile.num_vecs; + b_global += b_tile.num_vecs; + // Thread barrier if any shared memory is used. + if (a_tile.num_shared_vecs || b_tile.num_shared_vecs) { + __syncthreads(); + } } // Store accumulators. 
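+  // No barrier is needed before these stores: each thread writes back only
+  // its own c_regs through the thread-distributed c_global pointer, and the
+  // trailing __syncthreads() in the k-loop already ordered the last reads
+  // from shared memory.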
   for (int m = 0; m < intrinsics_m; ++m) {
     for (int n = 0; n < intrinsics_n; ++n) {
-      c_global[64 * (m * intrinsics_n + n)] = c[m][n];
+      c_global[64 * (m * intrinsics_n + n)] = c_regs[m][n];
     }
   }
 }
diff --git a/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir b/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir
index 654dc54fa888..f10577026e72 100644
--- a/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir
@@ -27,3 +27,4 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x4x16x2x8xi8>,
 // CHECK-NOT:    promote_operands
 // CHECK-SAME:   reduction = [0, 0, 0]
+// CHECK-SAME:   #iree_gpu.ukernel_config { } };
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
+static Value createSharedMemory(PatternRewriter &rewriter, Location loc,
+                                int sharedMemoryBytes) {
+  RankedTensorType tensorType =
+      RankedTensorType::get({sharedMemoryBytes}, rewriter.getI8Type());
+  ValueRange dynSizes{};
+  if (!sharedMemoryBytes) {
+    IREE::Codegen::NullPointerType nullPointerType =
+        IREE::Codegen::NullPointerType::get(rewriter.getContext());
+    return rewriter.create<IREE::Codegen::NullPointerOp>(loc, nullPointerType);
+  }
+  auto allocOp =
+      rewriter.create<bufferization::AllocTensorOp>(loc, tensorType, dynSizes);
+  Attribute sharedAddrSpace = gpu::AddressSpaceAttr::get(
+      rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  allocOp.setMemorySpaceAttr(sharedAddrSpace);
+  return allocOp;
+}
+
 struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
   LowerMultiMmaToUKernelPattern(MLIRContext *context)
       : OpRewritePattern<IREE::GPU::MultiMmaOp>(context) {}
@@ -100,14 +121,16 @@ struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
     if (!mma) {
       return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
     }
+    Location loc = op->getLoc();
+    Type I32Type = rewriter.getI32Type();
     auto castIndexToI32 = [&](Value val) {
-      return rewriter.create<arith::IndexCastOp>(op.getLoc(),
-                                                 rewriter.getI32Type(), val);
+      return rewriter.create<arith::IndexCastOp>(loc, I32Type, val);
     };
     auto constI32 = [&](int val) {
-      return rewriter.create<arith::ConstantIntOp>(op.getLoc(), val,
-                                                   rewriter.getI32Type());
+      return rewriter.create<arith::ConstantIntOp>(loc, val, I32Type);
     };
+    int64_t sharedMemoryBytes = ukernelAttr.getSharedMemoryBytes();
+    auto sharedMemory = createSharedMemory(rewriter, loc, sharedMemoryBytes);
     Value k = castIndexToI32(
         rewriter.create<tensor::DimOp>(op.getLoc(), op.getLhs(), 1));
     Value intrinsicsM = constI32(mma.getIntrinsicsM());
@@ -118,8 +141,8 @@ struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
     rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
         op, TypeRange{op.getAccType()}, ukernelAttr.getName(),
         ValueRange{op.getLhs(), op.getRhs()}, op.getAcc(),
-        ValueRange{k, intrinsicsM, subgroupsM, intrinsicsN, subgroupsN,
-                   intrinsicsK},
+        ValueRange{sharedMemory, constI32(sharedMemoryBytes), k, intrinsicsM,
+                   subgroupsM, intrinsicsN, subgroupsN, intrinsicsK},
         ukernelAttr.getDefAttrs(),
         /*strided_outer_dims=*/rewriter.getIndexAttr(0));
     return success();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 340fa65f3969..a9c5eb76204a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -120,6 +120,8 @@ def GPULowerToUKernelsPass :
       "::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect",
       "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
       "::mlir::arith::ArithDialect",
+      "::mlir::bufferization::BufferizationDialect",
+      "::mlir::gpu::GPUDialect",
       "::mlir::tensor::TensorDialect",
   ];
 }
diff --git
a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
index 2095393d5709..5d143ac2aec6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
@@ -76,17 +76,33 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x1x1x2x8xi8>, %b : te
     kind = #iree_gpu.data_tiled_mma_layout,
     lowering_config = #iree_gpu.lowering_config<{
       reduction = [0, 0, 0],
-      ukernel = #iree_gpu.ukernel_config,
+      ukernel = #iree_gpu.ukernel_config,
       workgroup = [1, 1, 0]}>
   } : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
   return %d : tensor<1x1x1x8x2x1x1x4xi32>
 }

 // CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8(
-// CHECK-DAG: %c2_i32 = arith.constant 2 : i32
-// CHECK-DAG: %c8_i32 = arith.constant 8 : i32
-// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
-// CHECK-DAG: %c4_i32 = arith.constant 4 : i32
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic
+// CHECK: bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16384xi8>
+// CHECK: iree_codegen.ukernel.generic
+// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
+
+// -----
+
+func.func @multi_mma_mfma_i32_16x16x32_i8_one_subgroup_no_shared_memory(%a : tensor<1x2x8x1x1x2x8xi8>, %b : tensor<1x2x1x2x1x1x2x8xi8>, %c : tensor<1x1x1x8x2x1x1x4xi32>) -> tensor<1x1x1x8x2x1x1x4xi32> {
+  %d = iree_gpu.multi_mma %a, %b, %c {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.data_tiled_mma_layout,
+    lowering_config = #iree_gpu.lowering_config<{
+      reduction = [0, 0, 0],
+      ukernel = #iree_gpu.ukernel_config,
+      workgroup = [1, 1, 0]}>
+  } : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
+  return %d : tensor<1x1x1x8x2x1x1x4xi32>
+}
+
+// CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8_one_subgroup_no_shared_memory(
+// CHECK: iree_codegen.null_pointer
+// CHECK: iree_codegen.ukernel.generic
+// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
-// CHECK-SAME: (%c2_i32, %c8_i32, %c1_i32, %c2_i32, %c4_i32, %c2_i32 : i32, i32, i32, i32, i32, i32)
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir b/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
index 7dd5cf9a0f9e..8b9e5abcb47b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_to_calls.mlir
@@ -176,3 +176,16 @@ func.func @ukernel_generic_test_fndef_attrs(%arg0 : memref, index, index)
 // CHECK-SAME: hal.import.fields = ["processor_id", "processor_data"]
+
+// -----
+
+func.func @ukernel_with_null_pointer_arg() {
+  %0 = iree_codegen.null_pointer
+  iree_codegen.ukernel.generic "foo" ins(%0: !iree_codegen.null_pointer)
+  return
+}
+
+// CHECK-LABEL: func.func private @foo(!iree_codegen.null_pointer, index)
+// CHECK-DAG: %[[NULLPTR:.+]] = iree_codegen.null_pointer
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 : index
+// CHECK: call @foo(%[[NULLPTR]], %[[ZERO]]) : (!iree_codegen.null_pointer, index) -> ()
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel index ed860aeb5134..e0faa9e3722c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel @@ -19,6 +19,7 @@ exports_files([ "IREECodegenInterfaces.td", "IREECodegenOps.td", "UKernelOps.td", + "IREECodegenTypes.td", ]) iree_td_library( @@ -29,6 +30,7 @@ iree_td_library( "IREECodegenDialect.td", "IREECodegenInterfaces.td", "IREECodegenOps.td", + "IREECodegenTypes.td", "UKernelOps.td", ], include = ["*.td"], @@ -52,6 +54,7 @@ iree_compiler_cc_library( "IREECodegenInterfaces.cpp", "IREECodegenLibraryManager.cpp", "IREECodegenOps.cpp", + "IREECodegenTypes.cpp", "UKernelOps.cpp", ], hdrs = [ @@ -59,6 +62,7 @@ iree_compiler_cc_library( "IREECodegenDialect.h", "IREECodegenInterfaces.h", "IREECodegenOps.h", + "IREECodegenTypes.h", "UKernelOps.h", ], textual_hdrs = [ @@ -75,10 +79,13 @@ iree_compiler_cc_library( "LoweringConfigEnums.h.inc", "UKernelOps.cpp.inc", "UKernelOps.h.inc", + "IREECodegenTypes.cpp.inc", + "IREECodegenTypes.h.inc", ], deps = [ ":IREECodegenDialectGen", ":IREECodegenOpsGen", + ":IREECodegenTypesGen", ":LoweringConfigGen", ":LoweringConfigInterfaceGen", ":UKernelOpsGen", @@ -182,6 +189,26 @@ iree_gentbl_cc_library( deps = [":td_files"], ) +iree_gentbl_cc_library( + name = "IREECodegenTypesGen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "IREECodegenTypes.h.inc", + ), + ( + ["--gen-typedef-defs"], + "IREECodegenTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "IREECodegenTypes.td", + deps = [ + ":td_files", + "//compiler/src/iree/compiler/Codegen/Interfaces:td_files", + ], +) + iree_gentbl_cc_library( name = "UKernelOpsGen", tbl_outs = [ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt index 0e02fea71220..e1f90249e587 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt @@ -18,6 +18,7 @@ iree_cc_library( "IREECodegenDialect.h" "IREECodegenInterfaces.h" "IREECodegenOps.h" + "IREECodegenTypes.h" "UKernelOps.h" TEXTUAL_HDRS "IREECodegenAttrs.cpp.inc" @@ -28,7 +29,9 @@ iree_cc_library( "IREECodegenInterfaces.h.inc" "IREECodegenOps.cpp.inc" "IREECodegenOps.h.inc" + "IREECodegenTypes.cpp.inc" "IREECodegenTypes.h" + "IREECodegenTypes.h.inc" "LoweringConfigEnums.cpp.inc" "LoweringConfigEnums.h.inc" "UKernelOps.cpp.inc" @@ -39,10 +42,12 @@ iree_cc_library( "IREECodegenInterfaces.cpp" "IREECodegenLibraryManager.cpp" "IREECodegenOps.cpp" + "IREECodegenTypes.cpp" "UKernelOps.cpp" DEPS ::IREECodegenDialectGen ::IREECodegenOpsGen + ::IREECodegenTypesGen ::LoweringConfigGen ::LoweringConfigInterfaceGen ::UKernelOpsGen @@ -108,6 +113,16 @@ iree_tablegen_library( --gen-enum-defs LoweringConfigEnums.cpp.inc ) +iree_tablegen_library( + NAME + IREECodegenTypesGen + TD_FILE + "IREECodegenTypes.td" + OUTS + --gen-typedef-decls IREECodegenTypes.h.inc + --gen-typedef-defs IREECodegenTypes.cpp.inc +) + iree_tablegen_library( NAME UKernelOpsGen diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp index 8b02d9dbf6fd..10596a20e646 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp
@@ -7,12 +7,16 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"

 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp.inc"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/IR/DialectImplementation.h"

+// clang-format off
+#define GET_TYPEDEF_CLASSES
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp.inc" // IWYU pragma: export
+// clang-format on
+
 namespace mlir::iree_compiler::IREE::Codegen {

 struct IREECodegenDialectOpAsmInterface : public OpAsmDialectInterface {
@@ -44,6 +48,8 @@ void IREECodegenDialect::initialize() {
 #define GET_OP_LIST
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp.inc"
       >();
+
+  addTypes<NullPointerType>();
 }

 LogicalResult
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td
index a51ff09e552a..4b6633828ddc 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td
@@ -67,6 +67,7 @@ def IREECodegen_Dialect : Dialect {
     /// the module once and can reuse it across all invocations.
     std::mutex libraryMutex;
   }];
+  let useDefaultTypePrinterParser = 1;
   let useDefaultAttributePrinterParser = 1;
   let hasOperationAttrVerify = 1;
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h
index af7a9b3e7b69..12c673eea246 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENOPS_H_
 #define IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENOPS_H_

+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td
index a8579858ca27..86fdf36fede8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td
@@ -8,6 +8,7 @@
 #define IREE_CODEGEN_DIALECT_IREECODEGENOPS

 include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td"
+include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -101,7 +102,15 @@ def IREECodegen_ExtractStridedMetadataOp : Op {

+def IREECodegen_NullPointerOp : Op<IREECodegen_Dialect, "null_pointer", [Pure]> {
+  let summary = "Returns a null_pointer value.";
+  let description = [{
+    This is meant to be used only as an argument to microkernels.
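+
+    For example:
+
+    ```mlir
+    %null = iree_codegen.null_pointer
+    ```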
+  }];
+  let results = (outs NullPointer:$result);
+  let assemblyFormat = "attr-dict";
+}

 #endif // IREE_CODEGEN_DIALECT_IREECODEGENOPS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
new file mode 100644
index 000000000000..11c6ca8cafe6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
@@ -0,0 +1,15 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/DialectImplementation.h"
+
+// clang-format off
+#define GET_TYPEDEF_CLASSES
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp.inc" // IWYU pragma: export
+// clang-format on
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
index c152aeefc0d3..13a393d33914 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
@@ -13,6 +13,11 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LLVM.h"

+// clang-format off
+#define GET_TYPEDEF_CLASSES
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h.inc" // IWYU pragma: export
+// clang-format on
+
 namespace mlir::iree_compiler::IREE::Codegen {
 //===----------------------------------------------------------------------===//
 // Layout Struct Types.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.td
new file mode 100644
index 000000000000..93aae40f7d52
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.td
@@ -0,0 +1,21 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_CODEGEN_DIALECT_IREECODEGEN_TYPES
+#define IREE_CODEGEN_DIALECT_IREECODEGEN_TYPES
+
+include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td"
+
+def NullPointer
+    : TypeDef<IREECodegen_Dialect, "NullPointer"> {
+  let summary = "Pseudo null-pointer type. Lowers to a null pointer.";
+  let description = [{
+    This is meant to be used only as an argument to microkernels.
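+    The type is written `!iree_codegen.null_pointer` in IR. When passed to a
+    ukernel call, it lowers to a (null pointer, zero offset) pair of call
+    operands, standing in for a nullable tensor/memref argument such as an
+    optional shared-memory workspace.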
+  }];
+  let mnemonic = "null_pointer";
+}
+
+#endif // IREE_CODEGEN_DIALECT_IREECODEGEN_TYPES
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp
index 2b9ff94b5f7b..058b6d165cae 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp
@@ -100,6 +100,11 @@ static LogicalResult getCallOpType(MLIRContext *context,
                        indexType);
         return success();
       })
+      .Case([&](NullPointerType nullPointerType) {
+        callOperandTypes.push_back(nullPointerType);
+        callOperandTypes.push_back(IndexType::get(context));
+        return success();
+      })
       .Default([&](Type t) { return failure(); });
 }
@@ -131,6 +136,11 @@ static LogicalResult lowerToCallOperands(Location loc, RewriterBase &rewriter,
       }
       return success();
     })
+      .Case([&](NullPointerType /*unused*/) {
+        callOperands.push_back(operand);
+        callOperands.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+        return success();
+      })
       .Default([](Type) { return failure(); });
 }
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h
index 24fb97ab9c1b..82aa9eceff86 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_DIALECT_UKERNELOPS_H_
 #define IREE_COMPILER_CODEGEN_DIALECT_UKERNELOPS_H_

+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
 #include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index c897ce19bdd9..fc85ae1c41f8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -532,12 +532,13 @@ def IREEGPU_UKernelConfigAttr :
   let summary = "An attribute specifying a ukernel that an op can lower to.";
   let description = [{
     An attribute that can be applied to any operation to specify that it has
-    been match with a ukernel that is a legal lowering for it.
+    been matched with a ukernel that is a legal lowering for it.
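+    The shared_memory_bytes parameter records the size of the workgroup-memory
+    workspace that the selected ukernel expects. It is inferred at ukernel
+    selection time (possibly by interpreting ukernel bitcode) and is 0 when no
+    workspace is needed.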
  }];
  let assemblyFormat = "`<` struct(params) `>`";
  let parameters = (ins
    "StringAttr":$name,
-    "DictionaryAttr":$def_attrs
+    "DictionaryAttr":$def_attrs,
+    DefaultValuedParameter<"int64_t", "0", "Size in bytes of shared memory workspace">:$shared_memory_bytes
  );
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index e29d993f2f30..620a8f4fa310 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -6,6 +6,7 @@
 #include "iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h"

+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
@@ -529,14 +530,30 @@ struct HALInterfaceWorkgroupOpsConverter final
   }
 };

+class ConvertNullPointerOp : public ConvertToLLVMPattern {
+public:
+  ConvertNullPointerOp(MLIRContext *context, LLVMTypeConverter &converter)
+      : ConvertToLLVMPattern(IREE::Codegen::NullPointerOp::getOperationName(),
+                             context, converter) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::ZeroOp>(
+        op, LLVM::LLVMPointerType::get(getContext()));
+    return success();
+  }
+};
+
 } // namespace

 void populateLLVMConversionPatterns(MLIRContext *context,
                                     RewritePatternSet &patterns,
                                     LLVMTypeConverter &converter) {
-  patterns
-      .insert(
-          context, converter);
+  patterns.add(context, converter);
+  converter.addConversion([context](IREE::Codegen::NullPointerType type) {
+    return LLVM::LLVMPointerType::get(context);
+  });
 }

 void populateScalarizeMathOps(RewritePatternSet &patterns) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
index 66bd982ffa89..7a345bd63d1a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel
@@ -37,6 +37,9 @@ iree_compiler_cc_library(
        "//compiler/src/iree/compiler/Dialect/HAL/IR",
        "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
        "//compiler/src/iree/compiler/Utils",
+        "@llvm-project//llvm:ExecutionEngine",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:Interpreter",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AMDGPUDialect",
        "@llvm-project//mlir:AffineDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
index 98ee9404ff61..22487898dc06 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt
@@ -21,6 +21,9 @@ iree_cc_library(
      "LLVMGPUUtils.cpp"
      "PrefetchSharedMemoryCopy.cpp"
    DEPS
+      LLVMExecutionEngine
+      LLVMIRReader
+      LLVMInterpreter
      LLVMSupport
      MLIRAMDGPUDialect
      MLIRAffineDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp
index 8d81cf78ec61..d1b4c71139d0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp
@@ -9,6 +9,10 @@
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include
"iree/compiler/Utils/EmbeddedDataDirectory.h" +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/ExecutionEngine/GenericValue.h" +#include "llvm/ExecutionEngine/Interpreter.h" // Performs registration. +#include "llvm/IRReader/IRReader.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/AsmState.h" @@ -58,24 +62,38 @@ static UKernelNameAndSuffix getUKernelNameAndSuffix(Operation *op) { return {}; } -// Returns the UKernelConfigAttr for any op. Returns {} if no ukernel. -static IREE::GPU::UKernelConfigAttr getUKernelConfig(Operation *op) { +static int64_t getSharedMemoryBytes(IREE::GPU::TargetAttr gpuTarget) { + if (!gpuTarget) { + return 0; + } + IREE::GPU::TargetWgpAttr wgp = gpuTarget.getWgp(); + if (!wgp) { + return 0; + } + return wgp.getMaxWorkgroupMemoryBytes(); +} + +// Returns an initial UKernelConfigAttr containing the ukernel name and +// def_attrs. Does not yet contain bitcode-dependent fields such as shared +// memory size. Returns {} if no ukernel. +static IREE::GPU::UKernelConfigAttr getInitialUKernelConfig(Operation *op) { MLIRContext *context = op->getContext(); auto [name, suffix] = getUKernelNameAndSuffix(op); if (name.empty()) { return {}; } - auto target = IREE::HAL::ExecutableTargetAttr::lookup(op); - if (!hasUkernel(target, name)) { + auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op); + if (!hasUkernel(execTarget, name)) { return {}; } - if (isROCMBackend(target)) { + if (isROCMBackend(execTarget)) { auto nameAttr = StringAttr::get( context, llvm::formatv("iree_uk_amdgpu_{}_{}", name, suffix)); auto defsAttr = DictionaryAttr::get( context, {{StringAttr::get(context, "vm.import.module"), StringAttr::get(context, "rocm")}}); - return IREE::GPU::UKernelConfigAttr::get(context, nameAttr, defsAttr); + return IREE::GPU::UKernelConfigAttr::get(context, nameAttr, defsAttr, + /*shared_memory_bytes=*/0); } return {}; } @@ -92,21 +110,14 @@ static IREE::GPU::UKernelConfigAttr getUKernelConfig(Operation *op) { static IREE::HAL::ExecutableObjectAttr getUKernelBitcode(MLIRContext *context, IREE::HAL::ExecutableTargetAttr execTarget, - ArrayAttr sourceExecutableObjects, StringRef ukernelName) { - IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(execTarget); - if (!gpuTarget) { - return {}; - } - StringRef gpuArch = gpuTarget.getArch(); - std::string bitcodeFilename = llvm::formatv("{}.{}.bc", ukernelName, gpuArch); - + ArrayAttr sourceExecutableObjects, StringRef filename) { // Early-return if the source executable.objects already contain an object // with the expected file name. This happens with user-provided bitcode in the // source IR. if (sourceExecutableObjects) { for (Attribute a : sourceExecutableObjects) { if (auto object = dyn_cast(a)) { - if (object.getPath() == bitcodeFilename) { + if (object.getPath() == filename) { return object; } } @@ -116,9 +127,8 @@ getUKernelBitcode(MLIRContext *context, // No user-provided bitcode, so we search our embedded bitcode files in the // EmbeddedDataDirectory singleton. 
   std::optional<StringRef> bitcode;
-  EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) {
-    bitcode = dir.getFile(bitcodeFilename);
-  });
+  EmbeddedDataDirectory::withGlobal(
+      [&](EmbeddedDataDirectory &dir) { bitcode = dir.getFile(filename); });
   if (!bitcode) {
     return {};
   }
@@ -127,18 +137,19 @@
   auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get(
       VectorType::get({static_cast<int64_t>(bitcode->size())},
                       IntegerType::get(context, 8)),
-      bitcodeFilename, std::move(blob));
+      filename, std::move(blob));
   return IREE::HAL::ExecutableObjectAttr::get(
-      context, StringAttr::get(context, bitcodeFilename),
+      context, StringAttr::get(context, filename),
       cast<IREE::Util::SerializableAttrInterface>(bitcodeDenseAttr));
 }

+static constexpr char executableObjectsAttrName[] = "hal.executable.objects";
+
 // Walks parents ops from `op` to return the nearest hal.executable.objects
 // array attribute. If the parent hal.executable.variant is reached, its objects
 // attribute is returned.
 // Adapted from ExecutableTargetAttr::lookup.
-static ArrayAttr lookUpExecutableObjects(Operation *op,
-                                         StringRef executableObjectsAttrName) {
+static ArrayAttr lookUpExecutableObjects(Operation *op) {
   MLIRContext *context = op->getContext();
   auto attrId = StringAttr::get(context, executableObjectsAttrName);
   while (op) {
@@ -158,39 +169,217 @@ static ArrayAttr lookUpExecutableObjects(Operation *op,
   return {};
 }

+static std::string getBitcodeFilename(IREE::GPU::TargetAttr gpuTarget,
+                                      StringRef name) {
+  return llvm::formatv("{}.{}.bc", name, gpuTarget.getArch());
+}
+
+// Helper for getSharedMemoryBytes. Typical latency: 2 ms.
+// Evaluates the shared memory size required by the multi_mma microkernel by
+// interpreting a bitcode function with a specific name.
+// On failure, an op warning is emitted and {} is returned.
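+// For example, with hypothetical tiling parameters intrinsics_m = 8,
+// subgroups_m = 2, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2, the
+// query function would compute: A tile = 64 * (8 * 2) * 2 = 2048 vecs, shared
+// because subgroups_n > 1; B tile = 64 * (2 * 2) * 4 = 1024 vecs, shared
+// because subgroups_m > 1; so it returns 8 * (2048 + 1024) = 24576 bytes.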
+static std::optional<int> expensivelyEvaluateSharedMemoryBytes(
+    IREE::GPU::MultiMmaOp op, IREE::GPU::UKernelConfigAttr ukernelConfig,
+    IREE::HAL::ExecutableObjectAttr bitcodeObject,
+    IREE::GPU::TargetAttr gpuTarget) {
+  auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+
+  auto bitcodeData = bitcodeObject.getData();
+  std::string buffer;
+  buffer.resize(bitcodeData.getStorageSize());
+  if (failed(bitcodeObject.getData().serializeToBuffer(
+          op->getLoc(), llvm::endianness::native,
+          ArrayRef<char>{buffer.data(), buffer.size()}))) {
+    op.emitWarning("Failed to serialize bitcode.");
+    return {};
+  }
+  llvm::LLVMContext llvmContext;
+  llvm::Expected<std::unique_ptr<llvm::Module>> module =
+      llvm::getLazyBitcodeModule(
+          llvm::MemoryBufferRef{buffer, ukernelConfig.getName()}, llvmContext,
+          /*ShouldLazyLoadMetadata=*/true);
+  if (!module) {
+    op.emitWarning("Failed to parse bitcode module.");
+    return {};
+  }
+  llvm::EngineBuilder builder(std::move(module.get()));
+  std::string builderError;
+  builder.setEngineKind(llvm::EngineKind::Interpreter)
+      .setErrorStr(&builderError);
+  std::unique_ptr<llvm::ExecutionEngine> interpreter{builder.create()};
+  if (!interpreter) {
+    op.emitWarning("Failed to create the interpreter.");
+    return {};
+  }
+  std::string queryFuncName =
+      llvm::formatv("{}_query_shared_memory_bytes", ukernelConfig.getName());
+  llvm::Function *func = interpreter->FindFunctionNamed(queryFuncName);
+  if (!func) {
+    op.emitWarning(llvm::formatv(
+        "Bitcode does not contain a function named {}.", queryFuncName));
+    return {};
+  }
+  auto constI32 = [](int32_t val) {
+    llvm::GenericValue v;
+    v.IntVal = APInt(32, val);
+    return v;
+  };
+  SmallVector<llvm::GenericValue> args{
+      constI32(mma.getIntrinsicsM()), constI32(mma.getSubgroupsM()),
+      constI32(mma.getIntrinsicsN()), constI32(mma.getSubgroupsN()),
+      constI32(mma.getIntrinsicsK())};
+  if (func->arg_size() != args.size()) {
+    op.emitWarning(
+        llvm::formatv("Bitcode function {} takes {} arguments. Expected {}.",
+                      queryFuncName, func->arg_size(), args.size()));
+    return {};
+  }
+  llvm::GenericValue interpreterResult = interpreter->runFunction(func, args);
+  if (interpreter->hasError()) {
+    op.emitWarning(llvm::formatv("Error while interpreting bitcode: {}.",
+                                 interpreter->getErrorMessage()));
+    return {};
+  }
+  int sharedMemoryBytes = interpreterResult.IntVal.getSExtValue();
+
+  // Reject a ukernel that would consume too much shared memory, which we need
+  // to save for other purposes. This threshold can always be adjusted but we
+  // default to a low threshold to get an early signal.
+  int maxSharedMemoryBytes = getSharedMemoryBytes(gpuTarget) / 4;
+  if (sharedMemoryBytes > maxSharedMemoryBytes) {
+    op.emitWarning(llvm::formatv("The shared memory size {} required by the "
+                                 "ukernel exceeds the maximum allowed size {}.",
+                                 sharedMemoryBytes, maxSharedMemoryBytes));
+    return {};
+  }
+  return sharedMemoryBytes;
+}
+
+// Returns the shared memory size required by the multi_mma ukernel.
+// On failure, an op warning is emitted and {} is returned.
+// Uses a static cache to avoid calling expensivelyEvaluateSharedMemoryBytes
+// more than once per DataTiledMMAAttr value.
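+// Concurrency note: `cacheMutex` guards only lookup/insertion into `cache`,
+// while each CacheEntry carries its own mutex. Two threads racing on the same
+// key serialize on that entry's mutex, so the expensive evaluation runs once,
+// while evaluations for different keys can proceed concurrently.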
+static std::optional<int>
+getSharedMemoryBytes(IREE::GPU::MultiMmaOp op,
+                     IREE::GPU::UKernelConfigAttr ukernelConfig,
+                     IREE::HAL::ExecutableObjectAttr bitcodeObject,
+                     IREE::GPU::TargetAttr gpuTarget) {
+  auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+
+  // We use the stringification of the attributes, rather than the
+  // attributes themselves, as the key, to ensure it's self-contained and does
+  // not contain pointers to other objects, such as an `MLIRContext*`, which
+  // could go dangling.
+  std::string key = llvm::formatv("mma = {}, gpuTarget = {}", mma, gpuTarget);
+
+  struct CacheEntry {
+    std::optional<int> sharedMemoryBytes;
+    std::mutex mutex;
+    bool evaluated = false;
+  };
+
+  // The cache and the mutex guarding it.
+  // We store CacheEntry objects by pointer, so that we don't need to worry
+  // about entryPtr being invalidated.
+  static llvm::StringMap<std::unique_ptr<CacheEntry>> cache;
+  static std::mutex cacheMutex;
+
+  CacheEntry *entryPtr = nullptr;
+
+  {
+    // Critical section on `cacheMutex`. This is the only place where we
+    // access `cache`. When we later update a cache entry, we do so through
+    // `entryPtr`, independently of `cache`.
+    std::lock_guard<std::mutex> lock(cacheMutex);
+    auto iter = cache.find(key);
+    if (iter != cache.end()) {
+      // Cache hit. Early return.
+      return iter->second->sharedMemoryBytes;
+    }
+    // Cache miss. Create a new cache entry and acquire its mutex.
+    entryPtr =
+        cache.insert({key, std::make_unique<CacheEntry>()}).first->second.get();
+    entryPtr->mutex.lock();
+  }
+
+  // If the entry still isn't evaluated after we have acquired its mutex,
+  // perform the evaluation now.
+  if (!entryPtr->evaluated) {
+    entryPtr->sharedMemoryBytes = expensivelyEvaluateSharedMemoryBytes(
+        op, ukernelConfig, bitcodeObject, gpuTarget);
+    entryPtr->evaluated = true;
+  }
+
+  entryPtr->mutex.unlock();
+  return entryPtr->sharedMemoryBytes;
+}
+
+// Returns the finalized UKernelConfigAttr to use for `op`, or {} if `op` should
+// not use a ukernel.
+static IREE::GPU::UKernelConfigAttr
+finalizeConfig(IREE::GPU::MultiMmaOp op,
+               IREE::GPU::UKernelConfigAttr ukernelConfig,
+               IREE::HAL::ExecutableObjectAttr bitcodeObject,
+               IREE::GPU::TargetAttr gpuTarget) {
+  std::optional<int> sharedMemoryBytes =
+      getSharedMemoryBytes(op, ukernelConfig, bitcodeObject, gpuTarget);
+  if (!sharedMemoryBytes) {
+    // Could not evaluate sharedMemoryBytes. Prevent the ukernel selection.
+    return {};
+  }
+  return IREE::GPU::UKernelConfigAttr::get(
+      op->getContext(), ukernelConfig.getName(), ukernelConfig.getDefAttrs(),
+      *sharedMemoryBytes);
+}
+
+// Returns the finalized UKernelConfigAttr to use for `op`, or {} if `op` should
+// not use a ukernel.
+static IREE::GPU::UKernelConfigAttr
+finalizeConfig(Operation *op, IREE::GPU::UKernelConfigAttr ukernelConfig,
+               IREE::HAL::ExecutableObjectAttr bitcodeObject,
+               IREE::GPU::TargetAttr gpuTarget) {
+  if (auto multiMmaOp = dyn_cast<IREE::GPU::MultiMmaOp>(op)) {
+    return finalizeConfig(multiMmaOp, ukernelConfig, bitcodeObject, gpuTarget);
+  }
+  return ukernelConfig;
+}
+
 // Ensures that the op has ukernel bitcode as a hal.executable.object, stored
 // as a hal.executable.objects attribute on the op itself, ready to be hoisted
-// by the HoistExecutableObjects pass.
-// Returns failure if no bitcode was found for the configured ukernel.
-static LogicalResult
-ensureUKernelBitcode(Operation *op,
-                     IREE::GPU::UKernelConfigAttr ukernelConfig) {
-  constexpr StringLiteral executableObjectsAttrName = "hal.executable.objects";
-  auto target = IREE::HAL::ExecutableTargetAttr::lookup(op);
-  ArrayAttr sourceExecutableObjects =
-      lookUpExecutableObjects(op, executableObjectsAttrName);
+// by the HoistExecutableObjects pass, and returns the finalized config attr
+// with the remaining bitcode-dependent fields populated.
+// Returns {} if no bitcode was found for the configured ukernel, or if an error
+// occurred trying to infer bitcode-dependent config fields (which may require
+// interpreting bitcode).
+static IREE::GPU::UKernelConfigAttr ensureUKernelBitcodeAndFinalizeConfig(
+    Operation *op, IREE::GPU::UKernelConfigAttr ukernelConfig) {
   MLIRContext *context = op->getContext();
-  IREE::HAL::ExecutableObjectAttr bitcodeObject = getUKernelBitcode(
-      context, target, sourceExecutableObjects, ukernelConfig.getName());
+  if (!ukernelConfig) {
+    return {};
+  }
+  auto target = IREE::HAL::ExecutableTargetAttr::lookup(op);
+  IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(target);
+  if (!gpuTarget) {
+    return {};
+  }
+  std::string filename = getBitcodeFilename(gpuTarget, ukernelConfig.getName());
+
+  ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op);
+  IREE::HAL::ExecutableObjectAttr bitcodeObject =
+      getUKernelBitcode(context, target, sourceExecutableObjects, filename);
   if (!bitcodeObject) {
-    return failure();
+    return {};
   }
   op->setAttr(executableObjectsAttrName,
               ArrayAttr::get(context, bitcodeObject));
-  return success();
+  return finalizeConfig(op, ukernelConfig, bitcodeObject, gpuTarget);
 }

 } // namespace

 IREE::GPU::UKernelConfigAttr selectUKernel(Operation *op) {
-  IREE::GPU::UKernelConfigAttr ukernelConfig = getUKernelConfig(op);
-  if (!ukernelConfig) {
-    return {};
-  }
-  if (failed(ensureUKernelBitcode(op, ukernelConfig))) {
-    return {};
-  }
-  return ukernelConfig;
+  IREE::GPU::UKernelConfigAttr initialConfig = getInitialUKernelConfig(op);
+  return ensureUKernelBitcodeAndFinalizeConfig(op, initialConfig);
 }

 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
index c506ce522d54..1f0a91ea6060 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
@@ -259,3 +259,16 @@ builtin.module {
 // CHECK: llvm.zext %[[arg3]] : i32 to i64
 // CHECK: llvm.insertvalue %[[arg0]]
 // CHECK: llvm.insertvalue %[[arg2]]
+
+// -----
+// Test lowering of iree_codegen.null_pointer.
+module {
+  func.func private @foo(!iree_codegen.null_pointer)
+  func.func @null_pointer() {
+    %0 = iree_codegen.null_pointer
+    call @foo(%0) : (!iree_codegen.null_pointer) -> ()
+    return
+  }
+}
+// CHECK-LABEL: llvm.func @null_pointer
+// CHECK: llvm.mlir.zero : !llvm.ptr