diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 05878facba3c..6ff0ee9e618d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -70,6 +70,10 @@ static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
   return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
 }
 
+static bool is_AMD(MMAIntrinsic intrinsic) {
+  return is_AMD_MFMA(intrinsic) || is_AMD_WMMA(intrinsic);
+}
+
 static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
   // Not using Wave64 at all at the moment, so the only place where the
   // subgroup size is 64 is on CDNA* architectures.
@@ -130,6 +134,21 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   return {};
 }
 
+/// Returns the MNK shape for an intrinsic without an implemented concrete
+/// layout.
+static std::tuple<int64_t, int64_t, int64_t>
+getUnsupportedMNKShape(MMAIntrinsic intrinsic) {
+  switch (intrinsic) {
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {16, 16, 16};
+  default:
+    assert(false && "unexpected enum value");
+    return {};
+  }
+  return {};
+}
+
 MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
                                                 MMAFragment fragment) {
   switch (intrinsic) {
@@ -287,16 +306,19 @@ struct OpaqueMmaLayout {
   Type cType;
 };
 
-template <typename MMAIntrinsicType>
 static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
-                                          MMAIntrinsicType intrinsic) {
+                                          MMAIntrinsic intrinsic) {
   OpaqueMmaLayout o;
   std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
-  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
-  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
-  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
-  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
-  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  if (is_AMD(intrinsic)) {
+    auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+    auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+    o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+    o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+    o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  } else {
+    std::tie(o.mSize, o.nSize, o.kSize) = getUnsupportedMNKShape(intrinsic);
+  }
   return o;
 }
 
@@ -388,6 +410,12 @@ int64_t MMAAttr::getSubgroupSize() const {
 }
 
 FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
+  // Explicit distribution is currently unsupported for NV intrinsics.
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  if (intrinsic == MMAIntrinsic::NV_WMMA_F16_16x16x16_F16 ||
+      intrinsic == MMAIntrinsic::NV_WMMA_F32_16x16x16_F16) {
+    return failure();
+  }
   return IREE::GPU::MMAScope::Subgroup;
 }
 
@@ -856,6 +884,22 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
   return VirtualMMAAttr::get(context, intrinsicAttr);
 }
 
+static std::tuple<int64_t, int64_t, int64_t>
+getMNKShape(VirtualMMAIntrinsic type) {
+  // VMFMA (virtual MFMA) intrinsics interleave two MFMA instructions along
+  // the K dimension.
+  switch (type) {
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
+    return {16, 16, 32};
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
+    return {32, 32, 16};
+  }
+  assert(false && "unhandled virtual mma layout type.");
+  return {};
+}
+
 static std::tuple<Type, Type, Type>
 getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
   Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
@@ -878,6 +922,18 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
   return {};
 }
 
+static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+                                          VirtualMMAIntrinsic intrinsic) {
+  OpaqueMmaLayout o;
+  std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
+  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  return o;
+}
+
 std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
   MLIRContext *ctx = getContext();
   auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 3c06d851360b..ed993c7e4341 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -131,10 +131,14 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
+    // Intrinsics that do not specify a scope cannot be distributed.
+    if (failed(mma.getMmaScope()))
+      continue;
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
+
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
     intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
   }
   if (intrinsics.empty())
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
index 8e6e8f949907..127c85ed37c5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp
@@ -455,8 +455,8 @@ StringRef normalizeARMGPUTarget(StringRef target) {
 
 const WgpDetails *getAmpereWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   static const WgpDetails ampereWgp = {allComputeBits, allStorageBits,
@@ -474,8 +474,8 @@ const WgpDetails *getTuringWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   static const WgpDetails turingWgp = {allComputeBits, allStorageBits,
@@ -493,8 +493,8 @@ const WgpDetails *getVoltaWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   // clang-format off
   static const WgpDetails voltaWgp = {
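Reviewer note, not part of the diff: the ConfigUtils.cpp hunk reorders the loop so the scope check runs before any shape or element-type query, which is what keeps getMNKShape() from ever being asked about an NV WMMA intrinsic during schedule selection. A minimal standalone sketch of that contract follows; all types in it are simplified stand-ins (std::optional modeling FailureOr, a two-value enum), not IREE's real MMAAttr APIs.

// filter_sketch.cpp -- illustrates "scope check first, shape query after".
#include <cassert>
#include <cstdint>
#include <optional>
#include <tuple>
#include <vector>

enum class Intrinsic { AMD_MFMA_F32_16x16x16_F16, NV_WMMA_F32_16x16x16_F16 };

// Stand-in for FailureOr<MMAScope>; std::nullopt models failure().
std::optional<int> getMmaScope(Intrinsic i) {
  if (i == Intrinsic::NV_WMMA_F32_16x16x16_F16)
    return std::nullopt; // NV intrinsics opt out of explicit distribution.
  return 0; // Subgroup scope.
}

std::tuple<int64_t, int64_t, int64_t> getMNKShape(Intrinsic) {
  return {16, 16, 16}; // Both example intrinsics happen to be 16x16x16.
}

int main() {
  std::vector<Intrinsic> mmas = {Intrinsic::AMD_MFMA_F32_16x16x16_F16,
                                 Intrinsic::NV_WMMA_F32_16x16x16_F16};
  std::vector<std::tuple<int64_t, int64_t, int64_t>> candidates;
  for (Intrinsic mma : mmas) {
    // Mirrors the reordered loop: skip unscoped intrinsics before any
    // shape/type query.
    if (!getMmaScope(mma))
      continue;
    candidates.push_back(getMNKShape(mma));
  }
  assert(candidates.size() == 1); // Only the AMD intrinsic survives.
}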