[LAYOUTS] Implement generically getUniqueContigPerThread (#5840)
This enables vectorisation of global loads and shared-memory accesses in some
cases where we didn't use it before, as we now compute the order of the
elements by looking at the actual LinearLayout associated with the given shape
of the tensor, which is quite neat.

We end up touching a few things in the Scan lowering, as BlockedLayouts
converted to LinearEncodings may not have the same order for
elements/threads/warps. This is a feature, not a bug, as it takes us closer
to supporting arbitrary LinearEncodings within the tt.scan op.
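
As an illustration of the idea, a minimal sketch of the pattern this change moves towards. The wrapper function is hypothetical; only the helpers that appear in the diff below (toLinearEncoding, getOrder, getContigPerThread) are assumed to exist, along with the usual Triton/MLIR headers.

// Hypothetical helper: how many contiguous elements a thread owns along the
// fastest-varying dimension, derived from the LinearLayout for this concrete shape.
static unsigned contigAlongFastestDim(mlir::RankedTensorType ty) {
  auto lin = mlir::triton::gpu::toLinearEncoding(ty.getEncoding(), ty.getShape());
  auto order = lin.getOrder();               // order[0] is the fastest dimension
  return lin.getContigPerThread()[order[0]]; // unique contiguous elements per thread
}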
lezcano authored Feb 13, 2025
1 parent e7072a3 commit 06941f4
Showing 8 changed files with 152 additions and 187 deletions.
28 changes: 19 additions & 9 deletions include/triton/Analysis/Utility.h
@@ -89,14 +89,26 @@ class ScanLoweringHelper {
explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
srcShape = firstTy.getShape();
srcEncoding = firstTy.getEncoding();
legacyEncoding = firstTy.getEncoding();
srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
srcElementTypes = op.getElementTypes();
// The codegen does not support different element/thread/warp order so
// we choose one a priori. We choose that of the blocked encoding.
// When we generalise this code to other layouts we'll probably need to
// get rid of all this logic and the *Stride auxiliary methods
// and replace them by transposes and reshapes on the LinearLayout
if (auto blockedEncoding =
dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding)) {
order = llvm::to_vector(blockedEncoding.getOrder());
} else {
order = srcEncoding.getOrder();
}

for (const auto &t : op.getInputTypes()) {
if (t.getShape() != srcShape) {
op.emitError() << "shape mismatch";
}
if (t.getEncoding() != srcEncoding) {
if (t.getEncoding() != legacyEncoding) {
op.emitError() << "encoding mismatch";
}
}
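
For example (an illustration, not part of this diff): for a 2-D BlockedEncodingAttr with order = [1, 0], the constructor above keeps dimension 1 as the fastest-varying dimension for elements, threads and warps alike, even if the LinearEncoding derived for a particular shape would report a different order per kind; fixing one order a priori keeps the *Stride helpers below consistent until the lowering is rewritten directly on LinearLayout.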
@@ -111,12 +123,8 @@ class ScanLoweringHelper {
unsigned getNonAxisNumThreadsPerWarp();
// Return the flat numbers of threads computing independent scan results.
unsigned getNonAxisNumThreadsPerCTA();
// Return the number of warps per CTA along axis dim.
unsigned getAxisNumWarps();
// Return the number of warps per CTA along axis dim with unique data.
unsigned getAxisNumWarpsWithUniqueData();
// Return the number of threads per warp along axis dim.
unsigned getAxisNumThreadsPerWarp();
// Return the number of threads per warp along axis dim with unique data.
unsigned getAxisNumThreadsPerWarpWithUniqueData();
// Return the number of blocks along axis dim.
@@ -139,18 +147,20 @@ class ScanLoweringHelper {
Location getLoc() { return scanOp.getLoc(); }
unsigned getAxis() { return scanOp.getAxis(); }
bool getReverse() { return scanOp.getReverse(); }
triton::gpu::BlockedEncodingAttr getEncoding();
triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; }
llvm::ArrayRef<int64_t> getShape() { return srcShape; }
unsigned getNumOperands() { return scanOp.getNumOperands(); }
SmallVector<Type> getElementTypes() { return srcElementTypes; }
Attribute getSrcLayout() { return srcEncoding; }
SmallVector<unsigned> getOrder() { return order; }
Region &getCombineOp();

private:
triton::ScanOp scanOp;
Attribute srcEncoding;
triton::gpu::LinearEncodingAttr srcEncoding;
Attribute legacyEncoding;
llvm::ArrayRef<int64_t> srcShape;
SmallVector<Type> srcElementTypes;
SmallVector<unsigned> order;
};

// Helper class for lowering `tt.gather` operations. This class shares lowering
33 changes: 24 additions & 9 deletions lib/Analysis/Allocation.cpp
@@ -28,6 +28,8 @@ namespace triton {

// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;
// Max shmem LDS/STS instruction in bits
constexpr int kMaxShmemVecBitLength = 128;

static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
RankedTensorType dstTy) {
@@ -79,15 +81,17 @@ std::pair<unsigned, unsigned>
getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
const auto &inOrd = gpu::getOrder(srcLayout);
const auto &outOrd = gpu::getOrder(dstLayout);

auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
auto inOrd = srcLinAttr.getOrder();
auto outOrd = dstLinAttr.getOrder();

unsigned rank = srcTy.getRank();

unsigned srcContigPerThread =
gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
unsigned dstContigPerThread =
gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
unsigned innerDim = rank - 1;
unsigned inVec = outOrd[0] != innerDim ? 1
@@ -117,8 +121,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
Attribute dstLayout = dstTy.getEncoding();

assert(cvtNeedsSharedMemory(srcTy, dstTy));

const auto &outOrd = gpu::getOrder(dstLayout);
auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder();
scratchConfig.order = outOrd;

std::tie(scratchConfig.inVec, scratchConfig.outVec) =
@@ -129,6 +132,18 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
// is the max vectorisation
auto inBitWidth = isa<PointerType>(srcTy.getElementType())
? kPtrBitWidth
: srcTy.getElementTypeBitWidth();
auto outBitWidth = isa<PointerType>(dstTy.getElementType())
? kPtrBitWidth
: dstTy.getElementTypeBitWidth();
scratchConfig.inVec =
std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
scratchConfig.outVec =
std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);

// No padding is required if the tensor is 1-D, or if all dimensions except
// the first accessed dimension have a size of 1.
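
As a quick arithmetic check of the clamp introduced above (the numbers follow directly from the constants in this diff; the element types are illustrative): with kMaxShmemVecBitLength = 128, a 16-bit element type is capped at 128 / 16 = 8 elements per shared-memory access, a 32-bit type at 128 / 32 = 4, and pointer elements (counted as kPtrBitWidth = 64 bits) at 128 / 64 = 2.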
22 changes: 13 additions & 9 deletions lib/Analysis/AxisInfo.cpp
@@ -1222,15 +1222,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
// FIXME: This is not as good as it could be, as we don't need to restrict
// the analysis to one dimension. We should determine contiguity on the
// flattenOuts() layout
auto linAttr =
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
auto order = linAttr.getOrder();
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
auto uniqueContigPerThread = linAttr.getContigPerThread();
assert(order[0] < uniqueContigPerThread.size() &&
"Unexpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];
@@ -1247,8 +1248,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto *axisInfo = getAxisInfo(ptr);
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto linAttr =
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
auto order = linAttr.getOrder();
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1275,7 +1277,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto linAttr =
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
auto maskOrder = linAttr.getOrder();
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
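
A hedged sketch of how a lowering might consume these queries when choosing a vector width for a masked access. The call site, the axisInfoAnalysis object and the mask check are hypothetical; only getPtrContiguity and getMaskAlignment come from this file.

// Hypothetical call site combining the AxisInfo queries above (assumes <algorithm>).
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); // contiguous elements per thread
if (mask)
  vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); // stay within mask-constancy runs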
94 changes: 33 additions & 61 deletions lib/Analysis/Utility.cpp
@@ -22,37 +22,18 @@
#include "triton/Tools/Sys/GetEnv.hpp"

namespace mlir {
namespace {

using namespace triton;
using namespace triton::gpu;

int getParentAxis(Attribute layout, int axis) {
if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
return getParentAxis(sliceEncoding.getParent(), axis);
}
return axis;
}

SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}

} // namespace

// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentOrder(getSrcLayout())[0];
auto linearEncoding = toLinearEncoding(getSrcLayout(), getSrcShape());
return linearEncoding.getOrder()[0] == axis;
}

SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getOrder(srcLayout);
auto order = toLinearEncoding(getSrcLayout(), getSrcShape()).getOrder();
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
@@ -219,69 +200,59 @@ bool ReduceOpHelper::isSupportedLayout() {
}

unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
return getEncoding().getSizePerThread()[getAxis()];
return getEncoding().getContigPerThread()[getAxis()];
}

unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
sizePerThreads[getAxis()] = 1;
return product<unsigned>(sizePerThreads);
auto contigPerThread = getEncoding().getContigPerThread();
contigPerThread[getAxis()] = 1;
return product<unsigned>(contigPerThread);
}

Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }

unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
return getThreadsPerWarp(getEncoding())[getAxis()];
}

unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()];
return getEncoding().getThreadsPerWarp()[getAxis()];
}

unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
threadsPerWarp[getAxis()] = 1;
return product<unsigned>(threadsPerWarp);
auto nThreads = product(getEncoding().getThreadsPerWarp());
return nThreads / getAxisNumThreadsPerWarpWithUniqueData();
}

// Return the flat numbers of threads computing independent scan results.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
warpsPerCTA[getAxis()] = 1;
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
}

unsigned ScanLoweringHelper::getAxisNumWarps() {
return getWarpsPerCTA(getEncoding())[getAxis()];
auto nWarps = product(getEncoding().getWarpsPerCTA());
return (nWarps / getAxisNumWarpsWithUniqueData()) *
getNonAxisNumThreadsPerWarp();
}

unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()];
return getEncoding().getWarpsPerCTA()[getAxis()];
}

unsigned ScanLoweringHelper::getAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto contigPerThread = getEncoding().getContigPerThread();
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
unsigned axis = getAxis();
return ceil<unsigned>(
getShape()[axis],
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
(contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}

unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto contigPerThread = getEncoding().getContigPerThread();
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
auto rank = contigPerThread.size();
unsigned axis = getAxis();
unsigned numBlocks = 1;
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
for (unsigned i = 0; i < rank; i++) {
if (i == axis)
continue;
numBlocks *=
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
ceil<unsigned>(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] *
warpsPerCTA[i]));
}
return numBlocks;
@@ -290,7 +261,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
bool ScanLoweringHelper::isSupported() {
// TODO: Support the following cases:
// 1. Scan on non-blocking encodings
if (!isa<BlockedEncodingAttr>(srcEncoding))
if (!isa<BlockedEncodingAttr>(legacyEncoding))
return false;
return true;
}
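
A worked example of the block-count formula in getAxisNumBlocks() above (the numbers are illustrative, not from the diff): with getShape()[axis] = 1024, contigPerThread[axis] = 4, threadsPerWarp[axis] = 32 and warpsPerCTA[axis] = 2, one pass over the layout covers 4 * 32 * 2 = 256 elements along the axis, so the helper returns ceil(1024 / 256) = 4. getNonAxisNumBlocks() applies the same ceiling to every dimension except the axis and multiplies the results.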
@@ -578,42 +549,43 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
return ret;
}

BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
return cast<BlockedEncodingAttr>(srcEncoding);
}

unsigned ScanLoweringHelper::getAxisElementStride() {
auto order = getOrder(getEncoding());
auto order = getOrder();
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getContigPerThread(getEncoding())[dim];
stride *= getEncoding().getContigPerThread()[dim];
}
llvm_unreachable("Axis not found in order");
}

unsigned ScanLoweringHelper::getAxisThreadStride() {
auto order = getOrder(getEncoding());
auto encoding = getEncoding();
auto kThread = StringAttr::get(encoding.getContext(), "lane");
// OOOGHHH This is nasty. We should implement this lowering via LLs natively
// to avoid this
auto threadsPerWarp = encoding.basesPerDim(kThread, /*skipBroadcast=*/false);
auto order = getOrder();
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getEncoding().getThreadsPerWarp()[dim];
stride *= threadsPerWarp[dim];
}
llvm_unreachable("Axis not found in order");
}

unsigned ScanLoweringHelper::getAxisBlockStride() {
auto order = getOrder(getEncoding());
auto order = getOrder();
unsigned stride = 1;
auto sizePerThreads = getSizePerThread(getEncoding());
auto contigPerThread = getEncoding().getContigPerThread();
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
stride *= ceil<unsigned int>(getShape()[dim], contigPerThread[dim] *
threadsPerWarp[dim] *
warpsPerCTA[dim]);
}
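
A small worked example of the stride walk used by the helpers above (the values are illustrative): with order = [1, 0], axis = 0 and threadsPerWarp = [4, 8], getAxisThreadStride() first visits dimension 1, multiplying the stride by threadsPerWarp[1] = 8, and then returns 8 when it reaches the axis; if the axis were the fastest dimension (axis = 1), the returned stride would be 1. getAxisElementStride() and getAxisBlockStride() follow the same pattern using contigPerThread and the per-dimension block counts, respectively.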
@@ -9,7 +9,6 @@ using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;