Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Refactor DotOpMFMAConversionHelper #5862

Merged
merged 1 commit into from
Feb 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 66 additions & 56 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,67 @@ struct DotOpMFMAConversionHelper {
return processSubBlocks(numSubBlocks, acc, false, true);
}

/// Dot operand layout minimal tile is kDimInstrSize elements across
/// K dimension. If dot operand K dimension is smaller, layout
/// assigns tensor elements to multiple different hardware locations.
/// In this case mfma instruction adds elements in accumulator
/// multiple times.
///
/// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
/// Consider instruction K size is 4,
/// in this case operands will be duplicated:
/// A' = [1,2,1,2] B' = [3,4,3,4]
/// C' = (1*3+2*4) + (1*3+2*4) = 22
///
/// Following code adjusts accumulator values in such cases.
/// If accumulator is integer, shift accumulator right by
/// log2(duplicationRate). If accumulator is float, multiply accum
/// with 1/duplicationRate constant.
void adjustAccForSmallKDim(SmallVector<Value> &fc, Value &acc, Type dstElemTy,
int b, int m, int n, int64_t numRepM,
int64_t numRepN, int64_t kDimInstrSize,
int64_t kDimOperandSize,
unsigned elemsPerVec) const {
auto tb = TritonLLVMOpBuilder(loc, rewriter);
for (unsigned v = 0; v < elemsPerVec; ++v) {
Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v));
if (kDimInstrSize > kDimOperandSize) {
assert(kDimInstrSize % kDimOperandSize == 0);
int duplicationRate = kDimInstrSize / kDimOperandSize;
assert(llvm::isPowerOf2_32(duplicationRate));
if (dstElemTy.isInteger()) {
auto shiftSize = llvm::Log2_32(duplicationRate);
assert(!accElem.getType().isUnsignedInteger() &&
"MFMA uses signed accumulator");
accElem = tb.ashr(accElem, tb.i32_val(shiftSize));
} else {
auto multiplierAttr =
rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate);
auto multiplierVal =
rewriter.create<LLVM::ConstantOp>(loc, dstElemTy, multiplierAttr);
accElem = tb.fmul(accElem, multiplierVal);
}
}
auto linearIdx = b * numRepM * numRepN * elemsPerVec +
m * numRepN * elemsPerVec + n * elemsPerVec + v;
fc[linearIdx] = accElem;
}
}

void packAndReplaceResult(DotOp &op, SmallVector<Value> &fc,
FailureOr<MfmaInsn> maybeMfmaInsn, Type dstElemTy,
Type elemtTy, size_t mmaCount) const {
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
elemtTy);

rewriter.replaceOp(op, res);
}

// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto tb = TritonLLVMOpBuilder(loc, rewriter);
Expand Down Expand Up @@ -243,11 +304,6 @@ struct DotOpMFMAConversionHelper {
auto elemsPerVec = mDim * nDim * subBlocks / warpSize;

Value firstMfma;
auto setFirstMfma = [&](Value mfma) {
if (!firstMfma)
firstMfma = mfma;
};

auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int b = 0; b < numRepB; ++b) {
for (int m = 0; m < numRepM; ++m) {
Expand All @@ -269,49 +325,13 @@ struct DotOpMFMAConversionHelper {
operandA[kPack][{b, m, k}], acc)
: generateMFMAOp(mfmaInsnName, operandA[kPack][{b, m, k}],
operandB[kPack][{b, n, k}], acc);
setFirstMfma(acc);
if (!firstMfma)
firstMfma = acc;
}
}
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {
Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v));
// Dot operand layout minimal tile is kDimInstrSize elements across
// K dimension. If dot operand K dimension is smaller, layout
// assigns tensor elements to multiple different hardware locations.
// In this case mfma instruction adds elements in accumulator
// multiple times.
//
// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
// Consider instruction K size is 4,
// in this case operands will be duplicated:
// A' = [1,2,1,2] B' = [3,4,3,4]
// C' = (1*3+2*4) + (1*3+2*4) = 22
//
// Following code adjusts accumulator values in such cases.
// If accumulator is integer, shift accumulator right by
// log2(duplicationRate). If accumulator is float, multiply accum
// with 1/duplicationRate constant.
if (kDimInstrSize > kDimOperandSize) {
assert(kDimInstrSize % kDimOperandSize == 0);
int duplicationRate = kDimInstrSize / kDimOperandSize;
assert(llvm::isPowerOf2_32(duplicationRate));
if (dstElemTy.isInteger()) {
auto shiftSize = llvm::Log2_32(duplicationRate);
assert(!accElem.getType().isUnsignedInteger() &&
"MFMA uses signed accumulator");
accElem = tb.ashr(accElem, tb.i32_val(shiftSize));
} else {
auto multiplierAttr =
rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate);
auto multiplierVal = rewriter.create<LLVM::ConstantOp>(
loc, dstElemTy, multiplierAttr);
accElem = tb.fmul(accElem, multiplierVal);
}
}
auto linearIdx = b * numRepM * numRepN * elemsPerVec +
m * numRepN * elemsPerVec + n * elemsPerVec + v;
fc[linearIdx] = accElem;
}
adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
kDimInstrSize, kDimOperandSize, elemsPerVec);
}
}
}
Expand All @@ -325,19 +345,9 @@ struct DotOpMFMAConversionHelper {
if (setPrioOp && firstMfma)
setPrioOp->moveAfter(firstMfma.getDefiningOp());

// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

Type elemtTy = elemTyA;
const size_t mmaCount =
numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
elemtTy);

rewriter.replaceOp(op, res);
packAndReplaceResult(op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);

return success();
}
Expand Down
Loading