From 01597623da30cc6c667a04753f49b69ec6a8799f Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Fri, 31 Jan 2025 14:29:38 -0500
Subject: [PATCH] [Codegen][GPU] Finish splitting NV intrinsics from AMD ones
 (#19853)

The split of `WMMA_*` enums into Nvidia and AMD variants was half
finished. This completely splits the handling of each vendor. In the
process, because concrete layouts for Nvidia intrinsics are
unimplemented, the only supported case is opaque layouts via SPIR-V.
This required re-introducing `getMNKShape` per enum value rather than
inferring it from the layout.

This PR is effectively NFC, but unblocks enabling LLVMGPUTileAndFuse by
default for matmuls.
---
 .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp   | 70 +++++++++++++++++--
 .../Dialect/GPU/TargetUtils/ConfigUtils.cpp   |  8 ++-
 .../Dialect/GPU/TargetUtils/KnownTargets.cpp  | 12 ++--
 3 files changed, 75 insertions(+), 15 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 05878facba3c..6ff0ee9e618d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -70,6 +70,10 @@ static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
   return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
 }
 
+static bool is_AMD(MMAIntrinsic intrinsic) {
+  return is_AMD_MFMA(intrinsic) || is_AMD_WMMA(intrinsic);
+}
+
 static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
   // Not using Wave64 at all at the moment, so the only place where the
   // subgroup size is 64 is on CDNA* architectures.
@@ -130,6 +134,21 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   return {};
 }
 
+/// Returns the MNK shape for an intrinsic without an implemented concrete
+/// layout.
+static std::tuple<int64_t, int64_t, int64_t>
+getUnsupportedMNKShape(MMAIntrinsic intrinsic) {
+  switch (intrinsic) {
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {16, 16, 16};
+  default:
+    assert(false && "unexpected enum value");
+    return {};
+  }
+  return {};
+}
+
 MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
                                                 MMAFragment fragment) {
   switch (intrinsic) {
@@ -287,16 +306,19 @@ struct OpaqueMmaLayout {
   Type cType;
 };
 
-template <typename MMAIntrinsicType>
 static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
-                                          MMAIntrinsicType intrinsic) {
+                                          MMAIntrinsic intrinsic) {
   OpaqueMmaLayout o;
   std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
-  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
-  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
-  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
-  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
-  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  if (is_AMD(intrinsic)) {
+    auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+    auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+    o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+    o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+    o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  } else {
+    std::tie(o.mSize, o.nSize, o.kSize) = getUnsupportedMNKShape(intrinsic);
+  }
   return o;
 }
 
@@ -388,6 +410,12 @@ int64_t MMAAttr::getSubgroupSize() const {
 }
 
 FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
+  // Explicit distribution currently unsupported for NV intrinsics.
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  if (intrinsic == MMAIntrinsic::NV_WMMA_F16_16x16x16_F16 ||
+      intrinsic == MMAIntrinsic::NV_WMMA_F32_16x16x16_F16) {
+    return failure();
+  }
   return IREE::GPU::MMAScope::Subgroup;
 }
 
@@ -856,6 +884,22 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
   return VirtualMMAAttr::get(context, intrinsicAttr);
 }
 
+static std::tuple<int64_t, int64_t, int64_t>
+getMNKShape(VirtualMMAIntrinsic type) {
+  // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
+  // along the k dimension.
+  switch (type) {
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
+    return {16, 16, 32};
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
+    return {32, 32, 16};
+  }
+  assert(false && "unhandled virtual mma layout type.");
+  return {};
+}
+
 static std::tuple<Type, Type, Type>
 getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
   Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
@@ -878,6 +922,18 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
   return {};
 }
 
+static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+                                          VirtualMMAIntrinsic intrinsic) {
+  OpaqueMmaLayout o;
+  std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
+  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  return o;
+}
+
 std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
   MLIRContext *ctx = getContext();
   auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3c06d851360b..ed993c7e4341 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -131,10 +131,14 @@ static std::optional getMmaScheduleFromProblemAndTarget(
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
+    // Intrinsics that do not specify a scope cannot be distributed.
+    if (failed(mma.getMmaScope()))
+      continue;
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
+
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
     intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
   }
   if (intrinsics.empty())
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 8e6e8f949907..127c85ed37c5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -455,8 +455,8 @@ StringRef normalizeARMGPUTarget(StringRef target) {
 
 const WgpDetails *getAmpereWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   static const WgpDetails ampereWgp = {allComputeBits, allStorageBits,
@@ -474,8 +474,8 @@ const WgpDetails *getTuringWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   static const WgpDetails turingWgp = {allComputeBits, allStorageBits,
@@ -493,8 +493,8 @@ const WgpDetails *getVoltaWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   // clang-format off
   static const WgpDetails voltaWgp = {