Skip to content

Commit

Permalink
[AMD] Add MFMA and WMMA layouts to LinearEncodingTest
Browse files Browse the repository at this point in the history
This PR adds AMD specific layouts to LinearEncodingTest::DistributedEncodingToLinearEncoding test
and fixes few issues exposed by this test.
  • Loading branch information
binarman committed Jan 24, 2025
1 parent 5f34fcf commit b6a351c
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 47 deletions.
11 changes: 9 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ Row |
}
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getElemsPerInstrForOperands() const;
SmallVector<int64_t> getElemsPerInstrForOperands(int opIdx, int kWidth) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
Expand All @@ -1079,7 +1079,11 @@ Row |
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);
if (getVersion() == 2) {
contigPerThread[rank - 2] = 8;
if (getIsTransposed()) {
contigPerThread[rank - 1] = 8;
} else {
contigPerThread[rank - 2] = 8;
}
}
return contigPerThread;
};
Expand Down Expand Up @@ -1339,6 +1343,9 @@ vecIdx (index of the element in the quad; this is always along the k-dim)
contigPerThread[rank - 1] = kWidth;
else
contigPerThread[rank - 2] = kWidth;
if (auto wmma = mlir::dyn_cast<AMDWmmaEncodingAttr>(getParent())) {
assert(wmma.getVersion() != 1 && "WMMA v1 currently not implemented");
}
return contigPerThread;
};
}];
Expand Down
37 changes: 31 additions & 6 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,13 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth;
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
return elemsPerThread;
} else if (auto wmma = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
auto rep = wmma.getRepForOperand(shape, eltTy, kWidth, idx);
if (rank == 3)
elemsPerThread[0] = rep[0];
elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth;
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
return elemsPerThread;
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
assert(getCTALayout(*this) ==
CTALayoutAttr::getDefault(getContext(), rank) &&
Expand Down Expand Up @@ -986,8 +993,13 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
if (mlir::isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(getParent())) {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ false);
} else {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
}
}

LogicalResult DotOperandEncodingAttr::verify(
Expand Down Expand Up @@ -2126,7 +2138,8 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
auto rank = getWarpsPerCTA().size();
SmallVector<unsigned> sizePerThread(rank, 1);
auto numReplicated = getVersion() == 1 ? 2 : 1;
auto elemsPerInstr = numReplicated * product(getElemsPerInstrForOperands()) /
auto elemsPerInstr = numReplicated *
product(getElemsPerInstrForOperands(opIdx, kWidth)) /
product(getThreadsPerWarp());
if (opIdx == 0) {
sizePerThread[rank - 2] = 1;
Expand All @@ -2146,15 +2159,26 @@ unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
return product(rep) * kWidth;
}

SmallVector<int64_t> AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const {
return {16, 16};
SmallVector<int64_t>
AMDWmmaEncodingAttr::getElemsPerInstrForOperands(int opIdx, int kWidth) const {
int64_t nonKSize = 16;
int64_t kSize = getVersion() == 1 ? kWidth : kWidth * 2;
switch (opIdx) {
case 0:
return {nonKSize, kSize};
case 1:
return {kSize, nonKSize};
default:
assert(false && "opidx should be 0 or 1");
}
return {};
}

SmallVector<int64_t>
AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth,
int opIdx) const {
auto operandTileShape = getElemsPerInstrForOperands();
auto operandTileShape = getElemsPerInstrForOperands(opIdx, kWidth);
assert(operandTileShape.size() == 2);
auto warpsPerCTA = getWarpsPerCTA();
auto rank = operandShape.size();
Expand All @@ -2177,6 +2201,7 @@ AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
}
}

// TODO: add kWidth argument
SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
// TODO: move magic numbers out of the code
return {16, 16, 16};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,

auto elemTy = aTensorTy.getElementType();
int kWidth = encoding.getKWidth();
auto elemsPerInstr = wmmaLayout.getElemsPerInstrForOperands();
auto elemsPerInstr = wmmaLayout.getElemsPerInstrForOperands(opIdx, kWidth);
auto wmmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0];
auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1];
assert(wmmaInstrNonK == 16);
Expand Down
158 changes: 120 additions & 38 deletions unittest/Dialect/TritonGPU/DialectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,65 +267,88 @@ INSTANTIATE_TEST_SUITE_P(
R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)"},
})));

class AMDLayoutTest : public ::testing::Test {
class AMDLayoutHelper {
public:
AMDLayoutTest() {
ctx.getOrLoadDialect<TritonGPUDialect>();
ctaLayout =
triton::gpu::CTALayoutAttr::get(&ctx, ctaPerCGA, ctaSplit, ctaOrder);
f16Ty = FloatType::getF16(&ctx);
void commonInitialization() {
ctx->getOrLoadDialect<TritonGPUDialect>();
ctaLayout3d = triton::gpu::CTALayoutAttr::get(ctx, /*ctaPerCGA*/ {1, 1, 1},
/*ctaSplit*/ {1, 1, 1},
/*ctaOrder*/ {2, 1, 0});
ctaLayout2d = triton::gpu::CTALayoutAttr::get(
ctx, /*ctaPerCGA*/ {1, 1}, /*ctaSplit*/ {1, 1}, /*ctaOrder*/ {1, 0});
f16Ty = FloatType::getF16(ctx);
}

triton::gpu::DotOperandEncodingAttr
createDotOperand(int idx, Attribute parent, int kWidth) {
return triton::gpu::DotOperandEncodingAttr::get(&ctx, idx, parent, kWidth);
AMDLayoutHelper(MLIRContext *externalCtx) {
ctx = externalCtx;
commonInitialization();
}

protected:
MLIRContext ctx;
const SmallVector<unsigned> ctaPerCGA{1, 1, 1};
const SmallVector<unsigned> ctaSplit{1, 1, 1};
const SmallVector<unsigned> ctaOrder{2, 1, 0};
triton::gpu::CTALayoutAttr ctaLayout;
Type f16Ty;
};
// This constructor creates MLIR context internally,
// pointer hold by unique_ptr, so destructor will automatically delete it.
AMDLayoutHelper() {
guardCtx = std::make_unique<MLIRContext>();
ctx = guardCtx.get();
commonInitialization();
}

class AMDMfmaLayoutTest : public AMDLayoutTest {
public:
AMDMfmaLayoutTest() = default;
triton::gpu::CTALayoutAttr getCTALayout(ArrayRef<unsigned> warpsPerCTA) {
switch (warpsPerCTA.size()) {
case 3:
return ctaLayout3d;
case 2:
return ctaLayout2d;
default:
assert(false && "unsupported rank for mma layout");
}
return nullptr;
}

triton::gpu::DotOperandEncodingAttr
createDotOperand(int idx, Attribute parent, int kWidth) {
return triton::gpu::DotOperandEncodingAttr::get(ctx, idx, parent, kWidth);
}

triton::gpu::AMDMfmaEncodingAttr createMFMA(int mDim, int nDim,
ArrayRef<unsigned> warpsPerCTA) {
return triton::gpu::AMDMfmaEncodingAttr::get(
&ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim,
/*isTransposed=*/false, ctaLayout);
ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim,
/*isTransposed=*/false, getCTALayout(warpsPerCTA));
}

triton::gpu::AMDMfmaEncodingAttr
createTransposedMFMA(int mDim, int nDim, ArrayRef<unsigned> warpsPerCTA) {
return triton::gpu::AMDMfmaEncodingAttr::get(
&ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim,
/*isTransposed=*/true, ctaLayout);
ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim,
/*isTransposed=*/true, getCTALayout(warpsPerCTA));
}
};

class AMDWmmaLayoutTest : public AMDLayoutTest {
public:
AMDWmmaLayoutTest() = default;

triton::gpu::AMDWmmaEncodingAttr
createWMMAv1(ArrayRef<unsigned> warpsPerCTA) {
return triton::gpu::AMDWmmaEncodingAttr::get(
&ctx, /*version=*/1, /*isTransposed=*/false, warpsPerCTA, ctaLayout);
ctx, /*version=*/1, /*isTransposed=*/false, warpsPerCTA,
getCTALayout(warpsPerCTA));
}

triton::gpu::AMDWmmaEncodingAttr
createWMMAv2(bool isTransposed, ArrayRef<unsigned> warpsPerCTA) {
return triton::gpu::AMDWmmaEncodingAttr::get(
&ctx, /*version=*/2, isTransposed, warpsPerCTA, ctaLayout);
return triton::gpu::AMDWmmaEncodingAttr::get(ctx, /*version=*/2,
isTransposed, warpsPerCTA,
getCTALayout(warpsPerCTA));
}

protected:
std::unique_ptr<MLIRContext> guardCtx;
MLIRContext *ctx;
triton::gpu::CTALayoutAttr ctaLayout2d;
triton::gpu::CTALayoutAttr ctaLayout3d;
Type f16Ty;
};

class AMDMfmaLayoutTest : public ::testing::Test, public AMDLayoutHelper {};

class AMDWmmaLayoutTest : public ::testing::Test, public AMDLayoutHelper {};

TEST_F(AMDMfmaLayoutTest, mfma32) {
auto mfma2d = createMFMA(32, 32, {2, 4});
ASSERT_THAT(mfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u));
Expand Down Expand Up @@ -538,6 +561,44 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
}
}

// Create an MFMA and DotOperandEncodingAttr
{
AMDLayoutHelper h(&ctx);
auto mfma16 = h.createMFMA(16, 16, {2, 2});
auto mfma32 = h.createMFMA(32, 32, {1, 2});
auto mfma16t = h.createTransposedMFMA(16, 16, {2, 2});
auto mfma32t = h.createTransposedMFMA(32, 32, {1, 2});
distributedEncodings.push_back(mfma16);
distributedEncodings.push_back(mfma32);
distributedEncodings.push_back(mfma16t);
distributedEncodings.push_back(mfma32t);
// Create an opIdx=0 and opIdx=1 encoding
for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
distributedEncodings.push_back(h.createDotOperand(opIdx, mfma16, 8));
distributedEncodings.push_back(h.createDotOperand(opIdx, mfma32, 4));
// Skip operands for transposed layouts,
// because they have same layout as non transpose variant
}
}

// Create an WMMA and DotOperandEncodingAttr
{
AMDLayoutHelper h(&ctx);
auto wmma1 = h.createWMMAv1({2, 2});
auto wmma2 = h.createWMMAv2(false, {2, 2});
auto wmma2t = h.createWMMAv2(true, {2, 2});
distributedEncodings.push_back(wmma1);
distributedEncodings.push_back(wmma2);
distributedEncodings.push_back(wmma2t);
// Create an opIdx=0 and opIdx=1 encoding
for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
distributedEncodings.push_back(h.createDotOperand(opIdx, wmma1, 16));
distributedEncodings.push_back(h.createDotOperand(opIdx, wmma2, 16));
// Skip operands for transposed layouts,
// because they have same layout as non transpose variant
}
}

for (const auto &distributedEncoding : distributedEncodings) {
for (auto shape : shapes) {
if (auto sliceEncoding =
Expand All @@ -558,19 +619,40 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
// Test that methods of DistributedEncoding return the same values
Type eltTy = FloatType::getF32(&ctx);

ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
// LinearLayout::getRepOrder works for some layouts,
// but gives unexpected order for MFMA and WMMA layouts
// TODO remove or rework this check
if (!mlir::isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(
distributedEncoding)) {
ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
ASSERT_EQ(distributedEncoding.getRepOrder(),
linearEncoding.getRepOrder());
}

ASSERT_EQ(cast<triton::gpu::TritonGPU_AttrTrait>(distributedEncoding)
.getTotalElemsPerThread(shape, eltTy),
linearEncoding.getTotalElemsPerThread(shape, eltTy));
ASSERT_EQ(cast<triton::gpu::TritonGPU_AttrTrait>(distributedEncoding)
.getElemsPerThread(shape, eltTy),
linearEncoding.getElemsPerThread(shape, eltTy));
ASSERT_EQ(distributedEncoding.getRepOrder(),
linearEncoding.getRepOrder());
ASSERT_EQ(distributedEncoding.getContigPerThread(),
linearEncoding.getContigPerThread());

auto dotOperand =
mlir::dyn_cast<DotOperandEncodingAttr>(distributedEncoding);
// Current implementation of WMMA v1 dot operand requires kWidth to be
// equal to number of elements processed by one WMMA instruction, which is
// a fixed number. At the same time, dot operand goes continuously along K
// dim, which depends on shape.
bool wmmaV1DotOp = false;
if (dotOperand) {
auto wmma = mlir::dyn_cast<AMDWmmaEncodingAttr>(dotOperand.getParent());
wmmaV1DotOp = wmma && wmma.getVersion() == 1;
}
if (!wmmaV1DotOp) {
ASSERT_EQ(distributedEncoding.getContigPerThread(),
linearEncoding.getContigPerThread());
}
// DotOperandEncodingAttr::getWarpOrder() is not defined
if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
if (!dotOperand) {
ASSERT_EQ(distributedEncoding.getWarpOrder(),
linearEncoding.getWarpOrder());
}
Expand Down

0 comments on commit b6a351c

Please sign in to comment.