Skip to content

Commit

Permalink
[AMD] NFC: Refactor DotOpMFMAConversionHelper (triton-lang#5862)
Browse files Browse the repository at this point in the history
This PR refactored `DotOpMFMAConversionHelper` by extracting utility
functions from `convertDot` to make it easier to be extended in
triton-lang#5845.
  • Loading branch information
knwng committed Feb 10, 2025
1 parent f6514ff commit f39dd8f
Showing 1 changed file with 66 additions and 56 deletions.
122 changes: 66 additions & 56 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,67 @@ struct DotOpMFMAConversionHelper {
return processSubBlocks(numSubBlocks, acc, false, true);
}

/// Dot operand layout minimal tile is kDimInstrSize elements across
/// K dimension. If dot operand K dimension is smaller, layout
/// assigns tensor elements to multiple different hardware locations.
/// In this case mfma instruction adds elements in accumulator
/// multiple times.
///
/// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
/// Consider instruction K size is 4,
/// in this case operands will be duplicated:
/// A' = [1,2,1,2] B' = [3,4,3,4]
/// C' = (1*3+2*4) + (1*3+2*4) = 22
///
/// Following code adjusts accumulator values in such cases.
/// If accumulator is integer, shift accumulator right by
/// log2(duplicationRate). If accumulator is float, multiply accum
/// with 1/duplicationRate constant.
void adjustAccForSmallKDim(SmallVector<Value> &fc, Value &acc, Type dstElemTy,
int b, int m, int n, int64_t numRepM,
int64_t numRepN, int64_t kDimInstrSize,
int64_t kDimOperandSize,
unsigned elemsPerVec) const {
auto tb = TritonLLVMOpBuilder(loc, rewriter);
for (unsigned v = 0; v < elemsPerVec; ++v) {
Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v));
if (kDimInstrSize > kDimOperandSize) {
assert(kDimInstrSize % kDimOperandSize == 0);
int duplicationRate = kDimInstrSize / kDimOperandSize;
assert(llvm::isPowerOf2_32(duplicationRate));
if (dstElemTy.isInteger()) {
auto shiftSize = llvm::Log2_32(duplicationRate);
assert(!accElem.getType().isUnsignedInteger() &&
"MFMA uses signed accumulator");
accElem = tb.ashr(accElem, tb.i32_val(shiftSize));
} else {
auto multiplierAttr =
rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate);
auto multiplierVal =
rewriter.create<LLVM::ConstantOp>(loc, dstElemTy, multiplierAttr);
accElem = tb.fmul(accElem, multiplierVal);
}
}
auto linearIdx = b * numRepM * numRepN * elemsPerVec +
m * numRepN * elemsPerVec + n * elemsPerVec + v;
fc[linearIdx] = accElem;
}
}

void packAndReplaceResult(DotOp &op, SmallVector<Value> &fc,
FailureOr<MfmaInsn> maybeMfmaInsn, Type dstElemTy,
Type elemtTy, size_t mmaCount) const {
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
elemtTy);

rewriter.replaceOp(op, res);
}

// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto tb = TritonLLVMOpBuilder(loc, rewriter);
Expand Down Expand Up @@ -265,11 +326,6 @@ struct DotOpMFMAConversionHelper {
auto elemsPerVec = mDim * nDim * subBlocks / warpSize;

Value firstMfma;
auto setFirstMfma = [&](Value mfma) {
if (!firstMfma)
firstMfma = mfma;
};

auto vecTy = vec_ty(dstElemTy, elemsPerVec);
for (int b = 0; b < numRepB; ++b) {
for (int m = 0; m < numRepM; ++m) {
Expand All @@ -291,49 +347,13 @@ struct DotOpMFMAConversionHelper {
operandA[kPack][{b, m, k}], acc)
: generateMFMAOp(mfmaInsnName, operandA[kPack][{b, m, k}],
operandB[kPack][{b, n, k}], acc);
setFirstMfma(acc);
if (!firstMfma)
firstMfma = acc;
}
}
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {
Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v));
// Dot operand layout minimal tile is kDimInstrSize elements across
// K dimension. If dot operand K dimension is smaller, layout
// assigns tensor elements to multiple different hardware locations.
// In this case mfma instruction adds elements in accumulator
// multiple times.
//
// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
// Consider instruction K size is 4,
// in this case operands will be duplicated:
// A' = [1,2,1,2] B' = [3,4,3,4]
// C' = (1*3+2*4) + (1*3+2*4) = 22
//
// Following code adjusts accumulator values in such cases.
// If accumulator is integer, shift accumulator right by
// log2(duplicationRate). If accumulator is float, multiply accum
// with 1/duplicationRate constant.
if (kDimInstrSize > kDimOperandSize) {
assert(kDimInstrSize % kDimOperandSize == 0);
int duplicationRate = kDimInstrSize / kDimOperandSize;
assert(llvm::isPowerOf2_32(duplicationRate));
if (dstElemTy.isInteger()) {
auto shiftSize = llvm::Log2_32(duplicationRate);
assert(!accElem.getType().isUnsignedInteger() &&
"MFMA uses signed accumulator");
accElem = tb.ashr(accElem, tb.i32_val(shiftSize));
} else {
auto multiplierAttr =
rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate);
auto multiplierVal = rewriter.create<LLVM::ConstantOp>(
loc, dstElemTy, multiplierAttr);
accElem = tb.fmul(accElem, multiplierVal);
}
}
auto linearIdx = b * numRepM * numRepN * elemsPerVec +
m * numRepN * elemsPerVec + n * elemsPerVec + v;
fc[linearIdx] = accElem;
}
adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
kDimInstrSize, kDimOperandSize, elemsPerVec);
}
}
}
Expand All @@ -347,19 +367,9 @@ struct DotOpMFMAConversionHelper {
if (setPrioOp && firstMfma)
setPrioOp->moveAfter(firstMfma.getDefiningOp());

// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

Type elemtTy = elemTyA;
const size_t mmaCount =
numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
elemtTy);

rewriter.replaceOp(op, res);
packAndReplaceResult(op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);

return success();
}
Expand Down

0 comments on commit f39dd8f

Please sign in to comment.