From 0753712ab79b0e165c935f977c86a68de48e6607 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 29 Jan 2025 19:48:03 +0100 Subject: [PATCH 1/5] [backend] NFC: Split architecture dependant and independant parts of FMA dot conversion (#5655) This PR splits FMA dot conversion from Triton GPU to LLVM in two parts: - Common code with iteration across M/N dim - Architecture dependant scalar multiplication of vectos across K dim This PR do not introduce any test, because it does not fix any bugs or introduce new functionality, it just refactors code. --- .../TritonGPUToLLVM/FMADotUtility.h | 35 ++++ lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 152 +++------------- .../DotOpToLLVM/FMADotUtility.cpp | 165 ++++++++++++++++++ 4 files changed, 224 insertions(+), 129 deletions(-) create mode 100644 include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h create mode 100644 lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp diff --git a/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h b/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h new file mode 100644 index 0000000000..907d36ed45 --- /dev/null +++ b/include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h @@ -0,0 +1,35 @@ +#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H +#define TRITON_CONVERSION_FMA_DOT_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::gpu { + +/// Abstract interface for scalar multiplication of Value vectors. +/// +/// Enable generation of hardware specific code in different backends. +class FMAVectorMultiplier { +public: + /// \returns scalar product of two arrays, plus c: a·b + c + virtual Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) = 0; + + virtual ~FMAVectorMultiplier() = default; +}; + +/// Implements a framework for FMA dot conversion to llvm. +/// +/// This function implements architecture independent part of FMA dot +/// conversion and calls "multiplier" object, which is defined by caller +/// and implements architecture dependant part of conversion. +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier); + +} // namespace mlir::triton::gpu + +#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 273f0e426d..dc21142d4c 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp DotOpToLLVM/FMA.cpp + DotOpToLLVM/FMADotUtility.cpp AllocateSharedMemory.cpp AssertOpToLLVM.cpp ControlFlowOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index ecf1d12914..0d6a0cad3d 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -1,144 +1,38 @@ -#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; using namespace mlir::triton; using namespace ::mlir::triton::gpu; -using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::expandMatrixOrderWithBatch; -using ::mlir::triton::gpu::expandMatrixShapeWithBatch; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getSizePerThread; - -/// \brief spatial position of repetition and register of a given value -struct OperandValueKey { - unsigned bRepIdx, nonKRepIdx; - unsigned bIdx, nonKIdx, kIdx; - - bool operator==(const OperandValueKey &other) const { - return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && - bIdx == other.bIdx && nonKIdx == other.nonKIdx && - kIdx == other.kIdx); - } -}; - -template <> struct std::hash { - std::size_t operator()(const OperandValueKey &k) const { - return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, - k.kIdx); +namespace { +class GenericFMAVectorMultiplier : public FMAVectorMultiplier { + OpBuilder &builder; + Location loc; + +public: + GenericFMAVectorMultiplier(OpBuilder &builder, Location loc) + : builder(builder), loc(loc) {} + + Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) override { + auto K = a.size(); + assert(b.size() == K); + Value accum = c; + for (auto [aElem, bElem] : llvm::zip(a, b)) + accum = builder.create(loc, aElem, bElem, accum); + return accum; } }; -using ValueTableFMA = std::unordered_map; - -static ValueTableFMA getValueTableFromStructFMA( - Value val, ArrayRef perRepShape, ArrayRef repetitions, - unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, - Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { - ValueTableFMA res; - auto elems = unpackLLElements(loc, val, rewriter); - assert(perRepShape.size() == 3); - auto numElemsRep = product(perRepShape); - assert(elems.size() == numElemsRep * product(repetitions)); - assert(kDim == 1 || kDim == 2); - assert(nonKDim == 1 || nonKDim == 2); - const unsigned bDim = 0; +} // namespace - for (unsigned idx = 0; idx < elems.size(); ++idx) { - auto inRepLinearIdx = idx % numElemsRep; - auto repLinearIdx = idx / numElemsRep; - auto inRepSpatialIdx = - mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); - auto repSpatialIdx = - mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); - OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], - inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], - inRepSpatialIdx[kDim]}; - res[key] = elems[idx]; - } - return res; -} - -LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, +LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { auto *ctx = rewriter.getContext(); auto loc = op.getLoc(); - - auto A = op.getA(); - auto D = op.getResult(); - - auto aTensorTy = cast(A.getType()); - auto dTensorTy = cast(D.getType()); - - SmallVector aShapePerCTA = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); - auto dShapePerCTA = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); - - BlockedEncodingAttr dLayout = - cast(dTensorTy.getEncoding()); - // TODO process A and B operand separately - auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); - auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); - auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); - - Value llA = adaptor.getA(); - Value llB = adaptor.getB(); - - auto sizePerThread = - expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); - auto numElemsPerThread = product(sizePerThread); - auto shapePerCTATile = - expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); - - unsigned K = aShapePerCTA[2]; - - unsigned threadTileShape[3]; - unsigned repetitions[3]; - for (int i = 0; i < 3; ++i) { - repetitions[i] = - ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); - } - - auto has = getValueTableFromStructFMA( - llA, {sizePerThread[0], sizePerThread[1], K}, - {repetitions[0], repetitions[1], 1}, - /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); - auto hbs = getValueTableFromStructFMA( - llB, {sizePerThread[0], K, sizePerThread[2]}, - {repetitions[0], 1, repetitions[2]}, - /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); - - SmallVector acc = cc; - - for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) - for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) - for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) - for (unsigned b = 0; b < sizePerThread[0]; ++b) - for (unsigned m = 0; m < sizePerThread[1]; ++m) - for (unsigned n = 0; n < sizePerThread[2]; ++n) { - SmallVector multiDimAccumIdx = {b, m, n}; - unsigned linearInRepIdx = - linearize(multiDimAccumIdx, sizePerThread, inRepOrder); - SmallVector multiDimRepIdx = {bRep, mRep, nRep}; - unsigned linearRepIdx = - linearize(multiDimRepIdx, repetitions, repOrder); - unsigned linearAccumIdx = - linearInRepIdx + linearRepIdx * numElemsPerThread; - for (unsigned k = 0; k < K; ++k) { - auto aOp = has[{bRep, mRep, b, m, k}]; - auto bOp = hbs[{bRep, nRep, b, n, k}]; - acc[linearAccumIdx] = rewriter.create( - loc, aOp, bOp, acc[linearAccumIdx]); - } - } - - auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); - rewriter.replaceOp(op, res); - - return success(); + GenericFMAVectorMultiplier multiplier(rewriter, loc); + return parametricConvertFMADot(op, adaptor, typeConverter, rewriter, + multiplier); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp new file mode 100644 index 0000000000..f61b723c08 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -0,0 +1,165 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; + +namespace { + +/// OperandValueKey structure represents compile time part +/// of spatial coordinates of a value in a tensor. +/// +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be +/// defined as: +/// +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord) +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord) +/// k = kIdx +/// +/// Where: +/// CTABSize, CTANKSize: constants; +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components; +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components. +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } +}; + +} // namespace + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); + } +}; + +namespace { + +using ValueTableFMA = std::unordered_map; + +ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + res[key] = elems[idx]; + } + return res; +} + +} // namespace + +namespace mlir::triton::gpu { + +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto dTensorTy = cast(D.getType()); + + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + // TODO process A and B operand separately + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto numElemsPerThread = product(sizePerThread); + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); + + unsigned K = aShapePerCTA[2]; + + unsigned threadTileShape[3]; + unsigned repetitions[3]; + for (int i = 0; i < 3; ++i) { + repetitions[i] = + ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); + } + + auto has = getValueTableFromStructFMA( + llA, {sizePerThread[0], sizePerThread[1], K}, + {repetitions[0], repetitions[1], 1}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); + auto hbs = getValueTableFromStructFMA( + llB, {sizePerThread[0], K, sizePerThread[2]}, + {repetitions[0], 1, repetitions[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); + + SmallVector acc = cc; + + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearInRepIdx = + LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + LLVM::linearize(multiDimRepIdx, repetitions, repOrder); + unsigned linearAccumIdx = + linearInRepIdx + linearRepIdx * numElemsPerThread; + + SmallVector aOpVector; + SmallVector bOpVector; + + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + acc[linearAccumIdx] = multiplier.multiplyVectors( + aOpVector, bOpVector, acc[linearAccumIdx]); + } + + auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} + +} // namespace mlir::triton::gpu From b9eda84284762ec450d0b333f2ce35624bcd76a6 Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Wed, 29 Jan 2025 21:00:50 +0100 Subject: [PATCH 2/5] [BACKEND] Limit vector size to scratch size for convert_layout (#5746) Without this, we can get into a situation when the vector loads/stores would exceed the size of the scratch buffer (and trigger an assertion). Fixes #5745. --- lib/Analysis/Allocation.cpp | 6 ++++++ test/Conversion/tritongpu_to_llvm.mlir | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 14563810e2..9e45ebe6aa 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, std::tie(scratchConfig.inVec, scratchConfig.outVec) = getScratchCvtInOutVecLengths(srcTy, dstTy); + // We can't write a longer vector than the shape of shared memory. + // This shape might be smaller than the tensor shape in case we decided to + // do the conversion in multiple iterations. + unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]]; + scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim); + scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim); // No padding is required if the tensor is 1-D, or if all dimensions except // the first accessed dimension have a size of 1. diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 4cdcbc85c5..f60521d426 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1266,6 +1266,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // ----- +// Regression test for https://github.com/triton-lang/triton/issues/5745 +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], warp = [[1, 0], [2, 0], [4, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 2]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [1, 0]], warp = [[2, 0], [4, 0], [0, 1]], block = []}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { + // CHECK-LABEL: linear_layout_with_multiple_iterations + tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) { + %cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1> + // CHECK: inline_asm{{.*}}st.shared.v2 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK: nvvm.barrier0 + // CHECK: inline_asm{{.*}}st.shared.v2 + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> From 1186806d96be85e1d95d2535319f5e257024490a Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Wed, 29 Jan 2025 22:10:36 +0000 Subject: [PATCH 3/5] [LAYOUTS] Create a trait that implements Layout equality by comparing the LLs (#5747) As per title --------- Co-authored-by: Mogball --- include/triton/Dialect/Triton/IR/Traits.h | 3 ++- .../Dialect/Triton/IR/TritonInterfaces.td | 14 +++++++++++++ include/triton/Dialect/Triton/IR/TritonOps.td | 2 +- lib/Dialect/Triton/IR/Ops.cpp | 20 ------------------ lib/Dialect/Triton/IR/Traits.cpp | 21 +++++++++++++++++++ 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index 804b1648e9..dbbf876cb5 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -3,6 +3,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LogicalResult.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op); LogicalResult verifySameOperandsEncoding(Operation *op, bool allowTensorPointerType = false); - +LogicalResult verifyEquivalentType(Type typeA, Type typeB); LogicalResult verifySameOperandsAndResultEncoding(Operation *op, bool allowTensorPointerType = false); diff --git a/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/include/triton/Dialect/Triton/IR/TritonInterfaces.td index f51cca0bc2..a9188cbf63 100644 --- a/include/triton/Dialect/Triton/IR/TritonInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -2,6 +2,7 @@ #define TRITON_INTERFACES include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; @@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAn def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; +// A trait equivalent to InferTypeOpAdaptor, but that checks for structural +// equivalence of the layouts of the result rather than just layout equality. +def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{ + static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { + if (lhs.size() != rhs.size()) + return false; + return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) { + auto [lhs, rhs] = tup; + return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs)); + }); + } +}]>; + #endif // TRITON_INTERFACES diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 6d72beaca6..c638a4fd95 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [ def TT_TransOp : TT_Op<"trans", [Pure, TransposeOpInterface, - InferTypeOpAdaptorWithIsCompatible, + InferTypeOpWithLayoutEquivalence, SameOperandsAndResultElementType]> { let summary = "rearrange the dimensions of a tensor"; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index d8ed0492ce..991a919608 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -235,26 +235,6 @@ LogicalResult TransOp::inferReturnTypes( return success(); } -bool TransOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { - assert(lhs.size() == rhs.size()); - assert(lhs.size() == 1); - auto lhsType = cast(lhs[0]); - auto rhsType = cast(rhs[0]); - - if (lhsType.getShape() != rhsType.getShape()) - return false; - - auto lhsEnc = lhsType.getEncoding(); - auto rhsEnc = rhsType.getEncoding(); - // If there's no encoding or the encodings are the same - if (lhsEnc == rhsEnc) - return true; - - return cast(&lhsEnc.getDialect()) - ->verifyLayoutsAreEqual(lhsType.getShape(), lhsEnc, rhsEnc, {}) - .succeeded(); -} - //-- DotOp -- LogicalResult DotOp::inferReturnTypes(MLIRContext *context, std::optional location, diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index 690826f4ef..a38e37bb07 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -3,12 +3,33 @@ #include #include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "llvm/Support/ErrorHandling.h" using namespace mlir; +LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) { + auto tensorTypeA = dyn_cast(typeA); + auto tensorTypeB = dyn_cast(typeB); + if (!(bool(tensorTypeA) && bool(tensorTypeB))) + return typeA == typeB ? success() : failure(); + auto encodingA = tensorTypeA.getEncoding(); + auto encodingB = tensorTypeB.getEncoding(); + auto shapeA = tensorTypeA.getShape(); + auto shapeB = tensorTypeB.getShape(); + if (shapeA != shapeB) + return failure(); + + // If there's no encoding or the encodings are the same + if (encodingA == encodingB) + return success(); + + return cast(&encodingA.getDialect()) + ->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {}); +} + static LogicalResult verifySameEncoding(Type typeA, Type typeB, bool allowTensorPointerType) { // TODO(Keren): the allowTensorPointerType argument is a hack to allow. From 924468ecbb0d9fe22d5f355aa0eb39a69dba17c2 Mon Sep 17 00:00:00 2001 From: sfzhu93 <42506672+sfzhu93@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:14:18 -0800 Subject: [PATCH 4/5] [Frontend][Diagnostics] Improve emitting diagnostic information (#5581) ### Summary This PR enhances the current implementation for emitting diagnostic remarks by introducing a unified handler in `ir.cc`. This handler manages diagnostic information more effectively and disables the emission of IRs unless explicitly requested by the user. The `MLIR_ENABLE_DIAGNOSTICS` environment variable now controls all diagnostic emission settings, accepting one or more values from `{warnings, remarks, stacktraces, operations}`, separated by commas. Detailed usage instructions are available in the README. ### Background Previously, a new default LLVM `SourceManager` was configured in `nvidia/backend/compiler.py` to support remarks, applied in both `make_ttgir` and `make_llir`. However, a custom handler already existed in `ir.cc`, and a more robust design should extend this handler rather than create a new one. ### Changes - **Unified Handler**: Inspired by LLVM upstream [[PR 117669](https://github.com/llvm/llvm-project/pull/117669)](https://github.com/llvm/llvm-project/pull/117669), this PR implements a similar custom handler that supports various severity levels. The `MLIR_ENABLE_DIAGNOSTICS` environment variable now specifies the severity level: `warnings` for warnings and errors, and `remarks` for remarks, warnings, and errors. - **IR Emission Control**: By default, the MLIR diagnostic API emits IRs, which can clutter error messages or performance remarks. This PR suppresses IR emission unless explicitly enabled by the user, improving the readability of error messages and performance remarks. Users can specify `MLIR_ENABLE_DIAGNOSTICS=remarks,operations` to include IR operations in remarks. - **Stacktraces**: Previously, setting `MLIR_ENABLE_DIAGNOSTICS=1` enabled all diagnostic information with stacktraces. Now, the `stacktraces` parameter specifically enables stacktraces. For example, `MLIR_ENABLE_DIAGNOSTICS=remarks,operations,stacktraces` enables IR operations and stacktraces, displaying all remarks, warnings, and errors. - **Testing**: Updated existing Python tests to verify that combinations of operations and stacktraces are emitted successfully. ### Future Work - With the new handler in place, there is an opportunity to further enhance the readability of existing warnings and remarks. This will be a focus in future updates. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/python/test` for end-to-end tests - [ ] This PR does not need a test. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --- README.md | 10 +- python/src/ir.cc | 134 +++++++++++++++++-------- python/test/unit/test_perf_warning.py | 102 +++++++++++-------- third_party/nvidia/backend/compiler.py | 12 +-- 4 files changed, 161 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 0d5d08dde6..0f8da14ecd 100644 --- a/README.md +++ b/README.md @@ -232,8 +232,14 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass. - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). -- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the related IR operation of diagnostics (e.g., errors and warnings). -- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks. +- `MLIR_ENABLE_DIAGNOSTICS=` controls diagnostic emission in MLIR. + Options are: `warnings`, `remarks`, `stacktraces`, `operations`. + Use comma-separated values to customize output. For example, + `MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations, + while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with + stacktraces. By default, only errors are shown. Setting `warnings` includes + errors and warnings; `remarks` includes errors, warnings, and remarks. +- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`. - `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn. - `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1. - `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage. diff --git a/python/src/ir.cc b/python/src/ir.cc index 53451b706a..b5411dd428 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -140,6 +140,42 @@ class TritonOpBuilder { bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); }; +// Run the pass manager under a source manager diagnostic handler, which +// enables emitted MLIR diagnostics to directly reference Python source +// code. This diagnostic handler supports filtering diagnostic info by +// severity levels. +struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler { + TritonSourceMgrDiagnosticHandler(MLIRContext *ctx, + DiagnosticSeverity minSeverity) + : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { + setHandler([this, minSeverity](Diagnostic &diag) { + auto severity = diag.getSeverity(); + switch (severity) { + case DiagnosticSeverity::Error: + break; + case DiagnosticSeverity::Warning: + if (minSeverity == DiagnosticSeverity::Error) + return success(); + break; + case DiagnosticSeverity::Remark: + if (minSeverity == DiagnosticSeverity::Error || + minSeverity == DiagnosticSeverity::Warning) + return success(); + break; + case DiagnosticSeverity::Note: + // notes are handled somewhere else. + return failure(); + default: + llvm_unreachable("Unknown diagnostic severity"); + } + emitDiagnostic(diag); + return success(); + }); + } + + llvm::SourceMgr sourceMgr; +}; + std::string locationToString(Location loc) { std::string str; llvm::raw_string_ostream os(str); @@ -148,6 +184,23 @@ std::string locationToString(Location loc) { return str; } +// Function to parse a comma-separated string into a vector of C-style strings +llvm::SmallVector +parseCommaSeparatedValues(const std::string &input, + llvm::SmallVector &storage) { + llvm::SmallVector split; + llvm::SmallVector result; + StringRef(input.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + return result; +} + void outputWarning(Location loc, const std::string &msg) { std::string locStr = locationToString(loc); @@ -1691,8 +1744,6 @@ void init_triton_ir(py::module &&m) { .def("enable_debug", [](PassManager &self) { auto *context = self.getContext(); - bool haveDiagnostics = - ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); std::string funcToDump; if (!haveDump) { @@ -1700,18 +1751,8 @@ void init_triton_ir(py::module &&m) { if (!funcToDump.empty()) haveDump = true; } - if (haveDiagnostics || haveDump) { - context->disableMultithreading(); - } - if (haveDiagnostics) { - context->printOpOnDiagnostic(true); - context->printStackTraceOnDiagnostic(true); - context->getDiagEngine().registerHandler([](Diagnostic &diag) { - llvm::outs() << diag << "\n"; - return success(); - }); - } if (haveDump) { + context->disableMultithreading(); auto printingFlags = OpPrintingFlags(); printingFlags.elideLargeElementsAttrs(16); printingFlags.enableDebugInfo(); @@ -1741,6 +1782,8 @@ void init_triton_ir(py::module &&m) { // TODO: maybe dump module to file and print error for better // diagnostics + auto *context = mod.getContext(); + auto reproducerPath = triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); if (!reproducerPath.empty()) { @@ -1752,7 +1795,7 @@ void init_triton_ir(py::module &&m) { makeReproducer(anchorName, passes, op, reproducerPath); // But if the pass manager crashes, attempt to generate a local // reproducer instead. - mod.getContext()->disableMultithreading(); + context->disableMultithreading(); self.enableCrashReproducerGeneration(reproducerPath, /*genLocalReproducer=*/true); } @@ -1763,20 +1806,9 @@ void init_triton_ir(py::module &&m) { if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); !debugOnly.empty()) { - llvm::SmallVector split; llvm::SmallVector storage; - llvm::SmallVector debugTypes; - - StringRef(debugOnly.c_str()).split(split, ','); - llvm::transform(split, std::back_inserter(debugTypes), - [&storage](StringRef str) { - // StringRefs are not always null-terminated. - // The purpose for this storage pattern is to - // produce a collection of C-strings that are. - storage.push_back(str.str()); - return storage.back().c_str(); - }); - + llvm::SmallVector debugTypes = + parseCommaSeparatedValues(debugOnly, storage); ::llvm::DebugFlag = true; using namespace llvm; setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); @@ -1787,25 +1819,41 @@ void init_triton_ir(py::module &&m) { self.enableTiming(); } - // Run the pass manager under a source manager diagnostic handler, which - // enables emitted MLIR diagnostics to directly reference Python source - // code. This diagnostic handler will only filter for errors. - struct SourceMgrErrorDiagnosticHandler - : public SourceMgrDiagnosticHandler { - SourceMgrErrorDiagnosticHandler(MLIRContext *ctx) - : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { - setHandler([this](Diagnostic &diag) { - if (diag.getSeverity() != DiagnosticSeverity::Error) - return failure(); - emitDiagnostic(diag); - return success(); - }); + // setting up diagnostics + bool showOperations = false, showStacktraces = false, + showRemarks = false, showWarnings = false; + + if (auto enableDiagnostics = + triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS"); + !enableDiagnostics.empty()) { + llvm::SmallVector storage; + parseCommaSeparatedValues(enableDiagnostics, storage); + for (auto &str : storage) { + if (str == "warnings") { + showWarnings = true; + } else if (str == "remarks") { + showRemarks = true; + } else if (str == "stacktraces") { + showStacktraces = true; + } else if (str == "operations") { + showOperations = true; + } + // we show errors by default, so no need to set it } + } - llvm::SourceMgr sourceMgr; - }; - SourceMgrErrorDiagnosticHandler diagHandler(mod.getContext()); + DiagnosticSeverity minSeverity = showWarnings + ? DiagnosticSeverity::Warning + : DiagnosticSeverity::Error; + minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity; + TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity); + + context->printOpOnDiagnostic(showOperations); + context->printStackTraceOnDiagnostic(showStacktraces); + if (showStacktraces) { + context->disableMultithreading(); + } if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); }); diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index bdf45b0210..86bebdd71a 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -8,16 +8,12 @@ @contextmanager -def enable_remark_context(): +def enable_diagnostics_context(value): try: - os.environ["MLIR_ENABLE_REMARK"] = "1" + os.environ["MLIR_ENABLE_DIAGNOSTICS"] = value yield finally: - os.environ["MLIR_ENABLE_REMARK"] = "0" - - -def is_perf_warning_enabled(): - return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1" + os.environ["MLIR_ENABLE_DIAGNOSTICS"] = "" def is_cuda(): @@ -74,29 +70,39 @@ def matmul_kernel( c = tl.dot(a, b) tl.store(c_block_ptr, c) - with enable_remark_context(): - triton.compile( - triton.compiler.ASTSource( - fn=matmul_kernel, - signature={ - "a_ptr": "*fp32", - "b_ptr": "*fp32", - "c_ptr": "*fp32", - "M": "i32", - "N": "i32", - "K": "i32", - "stride_am": "i32", - "stride_ak": "i32", - "stride_bk": "i32", - "stride_bn": "i32", - "stride_cm": "i32", - "stride_cn": "i32", - }, - constexprs={}, - )) + signature = { + "a_ptr": "*fp32", + "b_ptr": "*fp32", + "c_ptr": "*fp32", + "M": "i32", + "N": "i32", + "K": "i32", + "stride_am": "i32", + "stride_ak": "i32", + "stride_bk": "i32", + "stride_bn": "i32", + "stride_cm": "i32", + "stride_cn": "i32", + } + with enable_diagnostics_context('remarks'): + triton.compile(triton.compiler.ASTSource( + fn=matmul_kernel, + signature=signature, + constexprs={}, + )) captured = capfd.readouterr() - assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark" + assert ("can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark" + assert "note: see current operation:" not in captured.err + + with enable_diagnostics_context('remarks,operations,stacktraces'): + triton.compile(triton.compiler.ASTSource( + fn=matmul_kernel, + signature=signature, + constexprs={}, + )) + captured = capfd.readouterr() + assert "note: diagnostic emitted with trace:" in captured.err assert "note: see current operation:" in captured.err @@ -126,25 +132,39 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) tl.store(out_ptr0 + (x4), tmp22, None) XBLOCK = 1024 - with enable_remark_context(): + + astsource_args = { + "fn": ldst_vec, + "signature": { + "in_ptr0": "*i64", + "in_ptr1": "*i64", + "in_ptr2": "*fp16", + "in_ptr3": "*fp32", + "out_ptr0": "*fp16", + "XBLOCK": "constexpr", + }, + "constexprs": {"XBLOCK": XBLOCK}, + } + + with enable_diagnostics_context('remarks'): triton.compile( - triton.compiler.ASTSource( - fn=ldst_vec, - signature={ - "in_ptr0": "*i64", - "in_ptr1": "*i64", - "in_ptr2": "*fp16", - "in_ptr3": "*fp32", - "out_ptr0": "*fp16", - "XBLOCK": "constexpr", - }, - constexprs={"XBLOCK": XBLOCK}, - ), + triton.compiler.ASTSource(**astsource_args), options={"num_warps": 1}, ) _, err = capfd.readouterr() assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + assert "note: see current operation:" not in err + + with enable_diagnostics_context('remarks,operations,stacktraces'): + triton.compile( + triton.compiler.ASTSource(**astsource_args), + options={"num_warps": 1}, + ) + + _, err = capfd.readouterr() + assert "note: see current operation:" in err + assert "note: diagnostic emitted with trace:" in err def test_remark_swp_op_before_operands(capfd, fresh_triton_cache): diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 1fcd7dc5b3..7563b7515b 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -238,12 +238,6 @@ def make_ttgir(mod, metadata, opt, capability): cluster_info.clusterDimX = opt.cluster_dims[0] cluster_info.clusterDimY = opt.cluster_dims[1] cluster_info.clusterDimZ = opt.cluster_dims[2] - # Set up Diagnostic - if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": - srcMgr = llvm.source_mgr() - _ = ir.source_mgr_diag(srcMgr, mod.context) - mod.context.printOpOnDiagnostic(True) - # TTIR -> TTGIR pm = ir.pass_manager(mod.context) pm.enable_debug() passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) @@ -299,11 +293,7 @@ def make_llir(self, src, metadata, options, capability): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() - # Set up Diagnostic - if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": - srcMgr = llvm.source_mgr() - _ = ir.source_mgr_diag(srcMgr, mod.context) - mod.context.printOpOnDiagnostic(True) + nvidia.passes.ttnvgpuir.add_lower_mma(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.convert.add_scf_to_cf(pm) From f47cc3eaaa11cf87ffd93127a5d57eed907bdcd5 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 29 Jan 2025 21:24:47 -0500 Subject: [PATCH 5/5] [PROTON] Reworked the mechanism for finding libraries for profiling backends. (#5751) We should be able to use the default CUPTI path when Triton is installed, whether in development mode or within site packages. Additionally, we resolved incorrect architecture-to-package name mappings in setup.py. --- python/setup.py | 9 ++---- third_party/proton/CMakeLists.txt | 5 ---- third_party/proton/csrc/Proton.cpp | 5 ++-- .../proton/csrc/include/Driver/Dispatch.h | 29 +++++++++---------- .../proton/csrc/include/Driver/GPU/CuptiApi.h | 3 ++ .../csrc/include/Profiler/GPUProfiler.h | 6 ++++ .../proton/csrc/include/Session/Session.h | 2 ++ .../proton/csrc/lib/Driver/GPU/CuptiApi.cpp | 10 ++----- .../csrc/lib/Profiler/Cupti/CuptiProfiler.cpp | 3 ++ .../Profiler/RocTracer/RoctracerProfiler.cpp | 1 + .../proton/csrc/lib/Session/Session.cpp | 24 ++++++++------- third_party/proton/proton/profile.py | 17 ++++++++++- third_party/proton/test/test_lib.py | 4 +-- 13 files changed, 69 insertions(+), 49 deletions(-) diff --git a/python/setup.py b/python/setup.py index 8109b5e235..241c714e36 100644 --- a/python/setup.py +++ b/python/setup.py @@ -315,7 +315,8 @@ def download_and_copy(name, src_func, dst_path, variable, version, url_func): base_dir = os.path.dirname(__file__) system = platform.system() arch = platform.machine() - arch = {"arm64": "aarch64"}.get(arch, arch) + # NOTE: This might be wrong for jetson if both grace chips and jetson chips return aarch64 + arch = {"arm64": "sbsa", "aarch64": "sbsa"}.get(arch, arch) supported = {"Linux": "linux", "Darwin": "linux"} url = url_func(supported[system], arch, version) src_path = src_func(supported[system], arch, version) @@ -407,11 +408,7 @@ def get_proton_cmake_args(self): if cupti_include_dir == "": cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include") cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir] - cupti_lib_dir = get_env_with_keys(["TRITON_CUPTI_LIB_PATH"]) - if cupti_lib_dir == "": - cupti_lib_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "lib", "cupti") - cmake_args += ["-DCUPTI_LIB_DIR=" + cupti_lib_dir] - roctracer_include_dir = get_env_with_keys(["ROCTRACER_INCLUDE_PATH"]) + roctracer_include_dir = get_env_with_keys(["TRITON_ROCTRACER_INCLUDE_PATH"]) if roctracer_include_dir == "": roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include") cmake_args += ["-DROCTRACER_INCLUDE_DIR=" + roctracer_include_dir] diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index e0fafb43a9..e7ef143120 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -29,11 +29,6 @@ if(APPLE) set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup") endif() -if(DEFINED CUPTI_LIB_DIR) - message(STATUS "CUPTI lib directory: ${CUPTI_LIB_DIR}") - add_compile_definitions(CUPTI_LIB_DIR=${CUPTI_LIB_DIR}) -endif() - include_directories(${CUPTI_INCLUDE_DIR}) include_directories(SYSTEM ${ROCTRACER_INCLUDE_DIR}) target_compile_definitions(proton PRIVATE __HIP_PLATFORM_AMD__) diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index 6e10792587..b4840cca9e 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -15,9 +15,10 @@ void initProton(pybind11::module &&m) { m.def("start", [](const std::string &path, const std::string &contextSourceName, - const std::string &dataName, const std::string &profilerName) { + const std::string &dataName, const std::string &profilerName, + const std::string &profilerPath) { auto sessionId = SessionManager::instance().addSession( - path, profilerName, contextSourceName, dataName); + path, profilerName, profilerPath, contextSourceName, dataName); SessionManager::instance().activateSession(sessionId); return sessionId; }); diff --git a/third_party/proton/csrc/include/Driver/Dispatch.h b/third_party/proton/csrc/include/Driver/Dispatch.h index 1d8ec017cd..81f6b5b329 100644 --- a/third_party/proton/csrc/include/Driver/Dispatch.h +++ b/third_party/proton/csrc/include/Driver/Dispatch.h @@ -44,13 +44,13 @@ namespace proton { struct ExternLibBase { using RetType = int; // Generic type, can be overridden in derived structs - static constexpr const char *name = ""; // Placeholder - static constexpr const char *defaultDir = ""; // Placeholder - static constexpr RetType success = 0; // Placeholder + static constexpr const char *name = ""; // Placeholder + static constexpr RetType success = 0; // Placeholder ExternLibBase() = delete; ExternLibBase(const ExternLibBase &) = delete; ExternLibBase &operator=(const ExternLibBase &) = delete; static inline void *lib{nullptr}; + static inline std::string defaultDir{""}; }; template class Dispatch { @@ -59,25 +59,24 @@ template class Dispatch { static void init(const char *name, void **lib) { if (*lib == nullptr) { - // First reuse the existing handle - *lib = dlopen(name, RTLD_NOLOAD); - } - if (*lib == nullptr) { - // If not found, try to load it from LD_LIBRARY_PATH - *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY); - } - if (*lib == nullptr) { - // If still not found, try to load it from the default path + // If not found, try to load it from the default path auto dir = std::string(ExternLib::defaultDir); if (dir.length() > 0) { auto fullPath = dir + "/" + name; *lib = dlopen(fullPath.c_str(), RTLD_LOCAL | RTLD_LAZY); + } else { + // Only if the default path is not set, we try to load it from the + // system. + // First reuse the existing handle + *lib = dlopen(name, RTLD_NOLOAD); + if (*lib == nullptr) { + // If not found, try to load it from LD_LIBRARY_PATH + *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY); + } } } if (*lib == nullptr) { - throw std::runtime_error("Could not find `" + std::string(name) + - "`. Make sure it is in your " - "LD_LIBRARY_PATH."); + throw std::runtime_error("Could not load `" + std::string(name) + "`"); } } diff --git a/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h b/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h index 495964923e..28c12f7c43 100644 --- a/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h +++ b/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h @@ -3,6 +3,7 @@ #include "cupti.h" #include "cupti_pcsampling.h" +#include namespace proton { @@ -106,6 +107,8 @@ CUptiResult pcSamplingStart(CUpti_PCSamplingStartParams *pParams); template CUptiResult pcSamplingStop(CUpti_PCSamplingStopParams *pParams); +void setLibPath(const std::string &path); + } // namespace cupti } // namespace proton diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h index bb7a063aa2..a12889278d 100644 --- a/third_party/proton/csrc/include/Profiler/GPUProfiler.h +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -41,6 +41,11 @@ class GPUProfiler : public Profiler, } bool isPCSamplingEnabled() const { return pcSamplingEnabled; } + ConcreteProfilerT &setLibPath(const std::string &libPath) { + pImpl->setLibPath(libPath); + return dynamic_cast(*this); + } + protected: // OpInterface void startOp(const Scope &scope) override { @@ -136,6 +141,7 @@ class GPUProfiler : public Profiler, : profiler(profiler) {} virtual ~GPUProfilerPimplInterface() = default; + virtual void setLibPath(const std::string &libPath) = 0; virtual void doStart() = 0; virtual void doFlush() = 0; virtual void doStop() = 0; diff --git a/third_party/proton/csrc/include/Session/Session.h b/third_party/proton/csrc/include/Session/Session.h index 8bbf34234e..88b9ab3496 100644 --- a/third_party/proton/csrc/include/Session/Session.h +++ b/third_party/proton/csrc/include/Session/Session.h @@ -72,6 +72,7 @@ class SessionManager : public Singleton { ~SessionManager() = default; size_t addSession(const std::string &path, const std::string &profilerName, + const std::string &profilerPath, const std::string &contextSourceName, const std::string &dataName); @@ -103,6 +104,7 @@ class SessionManager : public Singleton { private: std::unique_ptr makeSession(size_t id, const std::string &path, const std::string &profilerName, + const std::string &profilerPath, const std::string &contextSourceName, const std::string &dataName); diff --git a/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp index 2c399d31c7..f86db8de21 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp @@ -6,16 +6,10 @@ namespace proton { namespace cupti { -#define STRINGIFY(x) #x -#define TOSTRING(x) STRINGIFY(x) struct ExternLibCupti : public ExternLibBase { using RetType = CUptiResult; static constexpr const char *name = "libcupti.so"; -#ifdef CUPTI_LIB_DIR - static constexpr const char *defaultDir = TOSTRING(CUPTI_LIB_DIR); -#else - static constexpr const char *defaultDir = ""; -#endif + static inline std::string defaultDir = ""; static constexpr RetType success = CUPTI_SUCCESS; static void *lib; }; @@ -116,6 +110,8 @@ DEFINE_DISPATCH(ExternLibCupti, pcSamplingStart, cuptiPCSamplingStart, DEFINE_DISPATCH(ExternLibCupti, pcSamplingStop, cuptiPCSamplingStop, CUpti_PCSamplingStopParams *); +void setLibPath(const std::string &path) { ExternLibCupti::defaultDir = path; } + } // namespace cupti } // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp index 378441cf63..2c60b536d4 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -198,6 +198,9 @@ struct CuptiProfiler::CuptiProfilerPimpl : GPUProfiler::GPUProfilerPimplInterface(profiler) {} virtual ~CuptiProfilerPimpl() = default; + void setLibPath(const std::string &libPath) override { + cupti::setLibPath(libPath); + } void doStart() override; void doFlush() override; void doStop() override; diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index 317bdc5e32..f5d66907ed 100644 --- a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -189,6 +189,7 @@ struct RoctracerProfiler::RoctracerProfilerPimpl : GPUProfiler::GPUProfilerPimplInterface(profiler) {} virtual ~RoctracerProfilerPimpl() = default; + void setLibPath(const std::string &libPath) override {} void doStart() override; void doFlush() override; void doStop() override; diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 44cd4c5e22..26f0fbf891 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -9,17 +9,17 @@ namespace proton { namespace { -Profiler *getProfiler(const std::string &profilerName) { - if (proton::toLower(profilerName) == "cupti") { - return &CuptiProfiler::instance(); +Profiler *getProfiler(const std::string &name, const std::string &path) { + if (proton::toLower(name) == "cupti") { + return &CuptiProfiler::instance().setLibPath(path); } - if (proton::toLower(profilerName) == "cupti_pcsampling") { - return &CuptiProfiler::instance().enablePCSampling(); + if (proton::toLower(name) == "cupti_pcsampling") { + return &CuptiProfiler::instance().setLibPath(path).enablePCSampling(); } - if (proton::toLower(profilerName) == "roctracer") { + if (proton::toLower(name) == "roctracer") { return &RoctracerProfiler::instance(); } - throw std::runtime_error("Unknown profiler: " + profilerName); + throw std::runtime_error("Unknown profiler: " + name); } std::unique_ptr makeData(const std::string &dataName, @@ -71,8 +71,9 @@ void Session::finalize(OutputFormat outputFormat) { std::unique_ptr SessionManager::makeSession( size_t id, const std::string &path, const std::string &profilerName, - const std::string &contextSourceName, const std::string &dataName) { - auto profiler = getProfiler(profilerName); + const std::string &profilerPath, const std::string &contextSourceName, + const std::string &dataName) { + auto profiler = getProfiler(profilerName, profilerPath); auto contextSource = makeContextSource(contextSourceName); auto data = makeData(dataName, path, contextSource.get()); auto *session = new Session(id, path, profiler, std::move(contextSource), @@ -139,6 +140,7 @@ void SessionManager::removeSession(size_t sessionId) { size_t SessionManager::addSession(const std::string &path, const std::string &profilerName, + const std::string &profilerPath, const std::string &contextSourceName, const std::string &dataName) { std::lock_guard lock(mutex); @@ -149,8 +151,8 @@ size_t SessionManager::addSession(const std::string &path, } auto sessionId = nextSessionId++; sessionPaths[path] = sessionId; - sessions[sessionId] = - makeSession(sessionId, path, profilerName, contextSourceName, dataName); + sessions[sessionId] = makeSession(sessionId, path, profilerName, profilerPath, + contextSourceName, dataName); return sessionId; } diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 575c85b0ca..5ee01f7b47 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -1,6 +1,7 @@ import functools import triton import os +import pathlib from triton._C.libproton import proton as libproton from .hook import register_triton_hook, unregister_triton_hook @@ -20,6 +21,18 @@ def _select_backend() -> str: raise ValueError("No backend is available for the current target.") +def _get_backend_default_path(backend: str) -> str: + lib_path = "" + if backend == "cupti": + # First try to get the path from the environment variable that overrides the default path + lib_path = os.getenv("TRITON_CUPTI_LIB_PATH", None) + if lib_path is None: + # Get the default path for the cupti backend, + # which is the most compatible with the current CUPTI header file triton is compiled with + lib_path = str(pathlib.Path(__file__).parent.parent.absolute() / "backends" / "nvidia" / "lib" / "cupti") + return lib_path + + def _check_env(backend: str) -> None: if backend == "roctracer": hip_device_envs = ["HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"] @@ -79,10 +92,12 @@ def start( _check_env(backend) + backend_path = _get_backend_default_path(backend) + set_profiling_on() if hook and hook == "triton": register_triton_hook() - return libproton.start(name, context, data, backend) + return libproton.start(name, context, data, backend, backend_path) def activate(session: Optional[int] = None) -> None: diff --git a/third_party/proton/test/test_lib.py b/third_party/proton/test/test_lib.py index 4a8313660b..c1936c73d9 100644 --- a/third_party/proton/test/test_lib.py +++ b/third_party/proton/test/test_lib.py @@ -32,7 +32,7 @@ def test_op(): def test_session(tmp_path: pathlib.Path): temp_file = tmp_path / "test_session.hatchet" - session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "") libproton.deactivate(session_id) libproton.activate(session_id) libproton.finalize(session_id, "hatchet") @@ -42,7 +42,7 @@ def test_session(tmp_path: pathlib.Path): def test_add_metrics(tmp_path: pathlib.Path): temp_file = tmp_path / "test_add_metrics.hatchet" - libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend()) + libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "") id1 = libproton.record_scope() libproton.enter_scope(id1, "one") libproton.add_metrics(id1, {"a": 1.0, "b": 2.0})