Merge OpenAI Triton commit f47cc3e (#3319)
This PR changes the Triton base from
9a49104 to
f47cc3e (Jan 29).
Pass rate: 98.19%

Please do not squash and merge this PR.
anmyachev authored Jan 31, 2025
2 parents b0ddc4b + 6ce6d5b commit ccf97fd
Showing 28 changed files with 518 additions and 297 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -232,8 +232,14 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the related IR operation of diagnostics (e.g., errors and warnings).
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
- `MLIR_ENABLE_DIAGNOSTICS=<comma-separated>` controls diagnostic emission in MLIR.
Options are: `warnings`, `remarks`, `stacktraces`, `operations`.
Use comma-separated values to customize output. For example,
`MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations,
while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with
stacktraces. By default, only errors are shown. Setting `warnings` includes
errors and warnings; `remarks` includes errors, warnings, and remarks.
- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`.
- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
35 changes: 35 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h
@@ -0,0 +1,35 @@
#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H
#define TRITON_CONVERSION_FMA_DOT_UTILITY_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::gpu {

/// Abstract interface for scalar multiplication of Value vectors.
///
/// Enables generation of hardware-specific code in different backends.
class FMAVectorMultiplier {
public:
/// \returns scalar product of two arrays, plus c: a·b + c
virtual Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
Value c) = 0;

virtual ~FMAVectorMultiplier() = default;
};

/// Implements a framework for FMA dot conversion to LLVM.
///
/// This function implements the architecture-independent part of the FMA dot
/// conversion and calls a "multiplier" object, which is defined by the caller
/// and implements the architecture-dependent part of the conversion.
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FMAVectorMultiplier &multiplier);

} // namespace mlir::triton::gpu

#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H
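
To make the contract concrete, here is a hedged sketch of a backend-specific implementation; `MulAddVectorMultiplier` is hypothetical and not part of this PR. It emits unfused multiply/add pairs instead of `llvm.fmuladd`, while all tiling and iteration logic stays in `parametricConvertFMADot`:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
#include "llvm/ADT/STLExtras.h"

#include <cassert>

namespace {
using namespace mlir;

// Hypothetical backend multiplier: emits a·b + c as unfused fmul/fadd
// pairs instead of llvm.fmuladd.
class MulAddVectorMultiplier : public triton::gpu::FMAVectorMultiplier {
  OpBuilder &builder;
  Location loc;

public:
  MulAddVectorMultiplier(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc) {}

  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                        Value c) override {
    assert(a.size() == b.size());
    Value accum = c;
    for (auto [aElem, bElem] : llvm::zip(a, b)) {
      Value prod = builder.create<LLVM::FMulOp>(loc, aElem, bElem);
      accum = builder.create<LLVM::FAddOp>(loc, accum, prod);
    }
    return accum;
  }
};
} // namespace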
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/Traits.h
@@ -3,6 +3,7 @@

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "triton/Dialect/Triton/IR/Types.h"

@@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op);

LogicalResult verifySameOperandsEncoding(Operation *op,
bool allowTensorPointerType = false);

LogicalResult verifyEquivalentType(Type typeA, Type typeB);
LogicalResult
verifySameOperandsAndResultEncoding(Operation *op,
bool allowTensorPointerType = false);
14 changes: 14 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonInterfaces.td
@@ -2,6 +2,7 @@
#define TRITON_INTERFACES

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
@@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">;
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;

// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
// equivalence of the layouts of the result rather than just layout equality.
def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
if (lhs.size() != rhs.size())
return false;
return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
auto [lhs, rhs] = tup;
return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs));
});
}
}]>;

#endif // TRITON_INTERFACES
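
In other words, the hook accepts result types whose encoding attributes differ but describe the same layout. A rough sketch of the idea, assuming a helper in the spirit of Triton's toLinearLayout (the signature below is simplified, not the exact API):

// Illustrative only: encodings are structurally equivalent when they induce
// the same linear layout for the same shape, e.g. a blocked encoding and an
// equivalent linear encoding. toLinearLayout here is a simplified stand-in
// for Triton's real API.
bool structurallyEquivalent(RankedTensorType a, RankedTensorType b) {
  if (a.getShape() != b.getShape() || a.getElementType() != b.getElementType())
    return false;
  return toLinearLayout(a.getShape(), a.getEncoding()) ==
         toLinearLayout(b.getShape(), b.getEncoding());
}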
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [

def TT_TransOp : TT_Op<"trans", [Pure,
TransposeOpInterface,
InferTypeOpAdaptorWithIsCompatible,
InferTypeOpWithLayoutEquivalence,
SameOperandsAndResultElementType]> {

let summary = "rearrange the dimensions of a tensor";
6 changes: 6 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

std::tie(scratchConfig.inVec, scratchConfig.outVec) =
getScratchCvtInOutVecLengths(srcTy, dstTy);
// We can't write a longer vector than the shape of shared memory.
// This shape might be smaller than the tensor shape in case we decided to
// do the conversion in multiple iterations.
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);

// No padding is required if the tensor is 1-D, or if all dimensions except
// the first accessed dimension have a size of 1.
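
For intuition, a self-contained sketch of the clamp added above; the numbers are illustrative, not from this PR:

#include <algorithm>
#include <array>
#include <cassert>

int main() {
  // Shared-memory repetition shape and dimension order, fastest-varying
  // dimension first; values are made up for illustration.
  std::array<unsigned, 2> repShape = {128, 8};
  std::array<unsigned, 2> order = {1, 0};
  unsigned inVec = 16, outVec = 16;

  // The contiguous shared-memory dimension holds repShape[order[0]] = 8
  // elements, so the 16-wide vectors must be clamped to 8.
  unsigned contiguousShapeDim = repShape[order[0]];
  inVec = std::min(inVec, contiguousShapeDim);
  outVec = std::min(outVec, contiguousShapeDim);

  assert(inVec == 8 && outVec == 8);
  return 0;
}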
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -1,6 +1,7 @@
add_triton_library(TritonGPUToLLVM
ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
DotOpToLLVM/FMA.cpp
DotOpToLLVM/FMADotUtility.cpp
AllocateSharedMemory.cpp
AssertOpToLLVM.cpp
ControlFlowOpToLLVM.cpp
152 changes: 23 additions & 129 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -1,144 +1,38 @@
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;
using namespace mlir::triton;
using namespace ::mlir::triton::gpu;

using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;

/// \brief spatial position of repetition and register of a given value
struct OperandValueKey {
unsigned bRepIdx, nonKRepIdx;
unsigned bIdx, nonKIdx, kIdx;

bool operator==(const OperandValueKey &other) const {
return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
kIdx == other.kIdx);
}
};

template <> struct std::hash<OperandValueKey> {
std::size_t operator()(const OperandValueKey &k) const {
return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
k.kIdx);
namespace {
class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
OpBuilder &builder;
Location loc;

public:
GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
: builder(builder), loc(loc) {}

Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
Value c) override {
auto K = a.size();
assert(b.size() == K);
Value accum = c;
for (auto [aElem, bElem] : llvm::zip(a, b))
accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
return accum;
}
};

using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;

static ValueTableFMA getValueTableFromStructFMA(
Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
ValueTableFMA res;
auto elems = unpackLLElements(loc, val, rewriter);
assert(perRepShape.size() == 3);
auto numElemsRep = product(perRepShape);
assert(elems.size() == numElemsRep * product(repetitions));
assert(kDim == 1 || kDim == 2);
assert(nonKDim == 1 || nonKDim == 2);
const unsigned bDim = 0;
} // namespace

for (unsigned idx = 0; idx < elems.size(); ++idx) {
auto inRepLinearIdx = idx % numElemsRep;
auto repLinearIdx = idx / numElemsRep;
auto inRepSpatialIdx =
mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
auto repSpatialIdx =
mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
inRepSpatialIdx[kDim]};
res[key] = elems[idx];
}
return res;
}

LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();

auto A = op.getA();
auto D = op.getResult();

auto aTensorTy = cast<RankedTensorType>(A.getType());
auto dTensorTy = cast<RankedTensorType>(D.getType());

SmallVector<int64_t> aShapePerCTA =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
auto dShapePerCTA =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));

BlockedEncodingAttr dLayout =
cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
// TODO process A and B operand separately
auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);

Value llA = adaptor.getA();
Value llB = adaptor.getB();

auto sizePerThread =
expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
auto numElemsPerThread = product(sizePerThread);
auto shapePerCTATile =
expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));

unsigned K = aShapePerCTA[2];

unsigned threadTileShape[3];
unsigned repetitions[3];
for (int i = 0; i < 3; ++i) {
repetitions[i] =
ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
}

auto has = getValueTableFromStructFMA(
llA, {sizePerThread[0], sizePerThread[1], K},
{repetitions[0], repetitions[1], 1},
/*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
auto hbs = getValueTableFromStructFMA(
llB, {sizePerThread[0], K, sizePerThread[2]},
{repetitions[0], 1, repetitions[2]},
/*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);

SmallVector<Value> acc = cc;

for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
for (unsigned b = 0; b < sizePerThread[0]; ++b)
for (unsigned m = 0; m < sizePerThread[1]; ++m)
for (unsigned n = 0; n < sizePerThread[2]; ++n) {
SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
unsigned linearInRepIdx =
linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
unsigned linearRepIdx =
linearize(multiDimRepIdx, repetitions, repOrder);
unsigned linearAccumIdx =
linearInRepIdx + linearRepIdx * numElemsPerThread;
for (unsigned k = 0; k < K; ++k) {
auto aOp = has[{bRep, mRep, b, m, k}];
auto bOp = hbs[{bRep, nRep, b, n, k}];
acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
loc, aOp, bOp, acc[linearAccumIdx]);
}
}

auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
rewriter.replaceOp(op, res);

return success();
GenericFMAVectorMultiplier multiplier(rewriter, loc);
return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
multiplier);
}
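
With this refactor, a backend-specific entry point reduces to constructing its multiplier and delegating. A hedged sketch reusing the hypothetical MulAddVectorMultiplier from the earlier example (convertMyBackendFMADot is likewise illustrative):

using namespace mlir;

LogicalResult convertMyBackendFMADot(triton::DotOp op,
                                     triton::DotOp::Adaptor adaptor,
                                     const LLVMTypeConverter *typeConverter,
                                     ConversionPatternRewriter &rewriter) {
  // The backend supplies only the scalar a*b + c step; the iteration over
  // repetitions and per-thread elements lives in parametricConvertFMADot.
  MulAddVectorMultiplier multiplier(rewriter, op.getLoc());
  return triton::gpu::parametricConvertFMADot(op, adaptor, typeConverter,
                                              rewriter, multiplier);
}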
