[Codegen][GPU] Finish splitting NV intrinsics from AMD ones (#19853)
The split of the `WMMA_*` enums into NVIDIA and AMD variants was half
finished. This completely splits the handling of each vendor. In the
process, because concrete layouts for the NVIDIA intrinsics are
unimplemented, the only supported case is opaque layouts via SPIR-V.
This required re-introducing `getMNKShape` per enum value rather than
inferring the shape from the layout.

This PR is effectively NFC, but unblocks enabling LLVMGPUTileAndFuse by
default for matmuls.
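
For orientation, a minimal sketch of the naming scheme after the split (illustrative only; the real enum is tablegen-generated and carries many more values than shown here):

    // Hypothetical abbreviation of MMAIntrinsic after this change: AMD
    // intrinsics keep their unprefixed MFMA_*/WMMA_* names, while the
    // NVIDIA variants carry an NV_ prefix and currently lower only
    // through the opaque SPIR-V path.
    enum class MMAIntrinsic {
      MFMA_F32_16x16x16_F16,    // AMD CDNA
      WMMA_F32_16x16x16_F16,    // AMD RDNA
      NV_WMMA_F32_16x16x16_F16, // NVIDIA
      NV_WMMA_F16_16x16x16_F16, // NVIDIA
    };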
qedawkins authored Jan 31, 2025
1 parent b9555fc commit 0159762
Showing 3 changed files with 75 additions and 15 deletions.
70 changes: 63 additions & 7 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -70,6 +70,10 @@ static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
   return getArchID(intrinsic) >= 0x1800 && getArchID(intrinsic) <= 0x1FFF;
 }
 
+static bool is_AMD(MMAIntrinsic intrinsic) {
+  return is_AMD_MFMA(intrinsic) || is_AMD_WMMA(intrinsic);
+}
+
 static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
   // Not using Wave64 at all at the moment, so the only place where the
   // subgroup size is 64 is on CDNA* architectures.
@@ -130,6 +134,21 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   return {};
 }
 
+/// Returns the MNK shape for an intrinsic without an implemented concrete
+/// layout.
+static std::tuple<int64_t, int64_t, int64_t>
+getUnsupportedMNKShape(MMAIntrinsic intrinsic) {
+  switch (intrinsic) {
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {16, 16, 16};
+  default:
+    assert(false && "unexpected enum value");
+    return {};
+  }
+  return {};
+}
+
 MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
                                                 MMAFragment fragment) {
   switch (intrinsic) {
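
Note on the addition above: because no concrete per-thread layout is implemented for the NV intrinsics, their M/N/K extents cannot be derived from getSingleSubgroupLayout the way the AMD ones are (see the next hunk); the shapes have to be tabulated per enum value, and both NV_WMMA variants are 16x16x16.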
@@ -287,16 +306,19 @@ struct OpaqueMmaLayout {
   Type cType;
 };
 
-template <typename MMAIntrinsicType>
 static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
-                                          MMAIntrinsicType intrinsic) {
+                                          MMAIntrinsic intrinsic) {
   OpaqueMmaLayout o;
   std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
-  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
-  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
-  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
-  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
-  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  if (is_AMD(intrinsic)) {
+    auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+    auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+    o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+    o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+    o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  } else {
+    std::tie(o.mSize, o.nSize, o.kSize) = getUnsupportedMNKShape(intrinsic);
+  }
   return o;
 }
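
As a worked example of the AMD branch (the layout values quoted here are an assumption of this note, taken from what getSingleSubgroupLayout returns for MFMA_F32_16x16x16_F16, not something shown in this diff): the Lhs fragment has outer = [1, 1], thread = [16, 4], element = [1, 4], so mSize = 1 * 16 * 1 = 16 and kSize = 1 * 4 * 4 = 16; the Rhs fragment mirrors this along the other dimension, giving nSize = 16 and recovering the 16x16x16 shape encoded in the intrinsic's name.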

@@ -388,6 +410,12 @@ int64_t MMAAttr::getSubgroupSize() const {
 }
 
 FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
+  // Explicit distribution currently unsupported for NV intrinsics.
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  if (intrinsic == MMAIntrinsic::NV_WMMA_F16_16x16x16_F16 ||
+      intrinsic == MMAIntrinsic::NV_WMMA_F32_16x16x16_F16) {
+    return failure();
+  }
   return IREE::GPU::MMAScope::Subgroup;
 }
 
@@ -856,6 +884,22 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
   return VirtualMMAAttr::get(context, intrinsicAttr);
 }
 
+static std::tuple<int64_t, int64_t, int64_t>
+getMNKShape(VirtualMMAIntrinsic type) {
+  // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
+  // along the k dimension.
+  switch (type) {
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
+    return {16, 16, 32};
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
+  case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
+    return {32, 32, 16};
+  }
+  assert(false && "unhandled virtual mma layout type.");
+  return {};
+}
+
 static std::tuple<Type, Type, Type>
 getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
   Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
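
For instance, VMFMA_F32_16x16x32_F16 in the table above interleaves two MFMA_F32_16x16x16_F16 operations along K, so M and N stay at 16 while K doubles to 32; the 32x32x16 cases likewise interleave two K=8 operations.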
@@ -878,6 +922,18 @@ getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
   return {};
 }
 
+static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
+                                          VirtualMMAIntrinsic intrinsic) {
+  OpaqueMmaLayout o;
+  std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
+  auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
+  auto rhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Rhs);
+  o.mSize = lhs.outer[0] * lhs.thread[0] * lhs.element[0];
+  o.kSize = lhs.outer[1] * lhs.thread[1] * lhs.element[1];
+  o.nSize = rhs.outer[1] * rhs.thread[1] * rhs.element[1];
+  return o;
+}
+
 std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
   MLIRContext *ctx = getContext();
   auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
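
The VirtualMMAIntrinsic overload above keeps the unconditional layout-derived computation that the MMAIntrinsic version now guards with is_AMD: every virtual intrinsic is an AMD VMFMA variant, so there is no NVIDIA fallback to consider.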
@@ -131,10 +131,14 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
+    // Intrinsics that do not specify a scope cannot be distributed.
+    if (failed(mma.getMmaScope()))
+      continue;
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
+
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
     intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
   }
   if (intrinsics.empty())
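
Two things change here: intrinsics whose getMmaScope fails (currently the NV ones, per the hunk above) are filtered out of the candidate list, and the shape and element-type queries move below the filters so they only run for intrinsics that survive.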
@@ -455,8 +455,8 @@ StringRef normalizeARMGPUTarget(StringRef target) {
 
 const WgpDetails *getAmpereWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   static const WgpDetails ampereWgp = {allComputeBits,
                                        allStorageBits,
@@ -474,8 +474,8 @@ const WgpDetails *getTuringWgpDetails() {
 
 const WgpDetails *getTuringWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   static const WgpDetails turingWgp = {allComputeBits,
                                        allStorageBits,
@@ -493,8 +493,8 @@ const WgpDetails *getVoltaWgpDetails() {
 
 const WgpDetails *getVoltaWgpDetails() {
   static const MMAIntrinsic mmaOps[] = {
-      MMAIntrinsic::WMMA_F32_16x16x16_F16,
-      MMAIntrinsic::WMMA_F16_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F32_16x16x16_F16,
+      MMAIntrinsic::NV_WMMA_F16_16x16x16_F16,
   };
   // clang-format off
   static const WgpDetails voltaWgp = {
