[LAYOUTS] Implement generically getUniqueContigPerThread #5840

Merged · 9 commits · Feb 13, 2025
28 changes: 19 additions & 9 deletions include/triton/Analysis/Utility.h
@@ -89,14 +89,26 @@ class ScanLoweringHelper {
explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
auto firstTy = cast<RankedTensorType>(op.getOperands()[0].getType());
srcShape = firstTy.getShape();
srcEncoding = firstTy.getEncoding();
legacyEncoding = firstTy.getEncoding();
srcEncoding = triton::gpu::toLinearEncoding(legacyEncoding, srcShape);
srcElementTypes = op.getElementTypes();
// The codegen does not support different element/thread/warp order so
// we choose one a priori. We choose that of the blocked encoding.
// When we generalise this code to other layouts we'll probably need to
// get rid of all this logic and the *Stride auxiliary methods
// and replace them by transposes and reshapes on the LinearLayout
if (auto blockedEncoding =
dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding)) {
order = llvm::to_vector(blockedEncoding.getOrder());
} else {
order = srcEncoding.getOrder();
}

for (const auto &t : op.getInputTypes()) {
if (t.getShape() != srcShape) {
op.emitError() << "shape mismatch";
}
if (t.getEncoding() != srcEncoding) {
if (t.getEncoding() != legacyEncoding) {
op.emitError() << "encoding mismatch";
}
}
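For context, a minimal sketch of the order selection the constructor performs (not part of this PR; assumes the usual mlir/llvm usings from this file and that the legacy encoding attribute and tensor shape are in scope):

// Sketch only: blocked encodings keep their own order; every other encoding
// falls back to the order of the equivalent linear encoding.
SmallVector<unsigned> pickOrder(Attribute legacyEncoding,
                                ArrayRef<int64_t> shape) {
  if (auto blocked =
          dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding))
    return llvm::to_vector(blocked.getOrder());
  return triton::gpu::toLinearEncoding(legacyEncoding, shape).getOrder();
}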
@@ -111,12 +123,8 @@ class ScanLoweringHelper {
unsigned getNonAxisNumThreadsPerWarp();
// Return the flat numbers of threads computing independent scan results.
unsigned getNonAxisNumThreadsPerCTA();
// Return the number of warps per CTA along axis dim.
unsigned getAxisNumWarps();
// Return the number of warps per CTA along axis dim with unique data.
unsigned getAxisNumWarpsWithUniqueData();
// Return the number of threads per warp along axis dim.
unsigned getAxisNumThreadsPerWarp();
// Return the number of threads per warp along axis dim with unique data.
unsigned getAxisNumThreadsPerWarpWithUniqueData();
// Return the number of blocks along axis dim.
@@ -139,18 +147,20 @@ class ScanLoweringHelper {
Location getLoc() { return scanOp.getLoc(); }
unsigned getAxis() { return scanOp.getAxis(); }
bool getReverse() { return scanOp.getReverse(); }
triton::gpu::BlockedEncodingAttr getEncoding();
triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; }
llvm::ArrayRef<int64_t> getShape() { return srcShape; }
unsigned getNumOperands() { return scanOp.getNumOperands(); }
SmallVector<Type> getElementTypes() { return srcElementTypes; }
Attribute getSrcLayout() { return srcEncoding; }
SmallVector<unsigned> getOrder() { return order; }
Region &getCombineOp();

private:
triton::ScanOp scanOp;
Attribute srcEncoding;
triton::gpu::LinearEncodingAttr srcEncoding;
Attribute legacyEncoding;
llvm::ArrayRef<int64_t> srcShape;
SmallVector<Type> srcElementTypes;
SmallVector<unsigned> order;
};

// Helper class for lowering `tt.gather` operations. This class shares lowering
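A short usage sketch of the updated helper interface (illustrative only; `op` is assumed to be a valid triton::ScanOp, and none of these lines are part of the PR):

ScanLoweringHelper helper(op);
triton::gpu::LinearEncodingAttr enc = helper.getEncoding();
unsigned elemsPerThread = helper.getAxisNumElementsPerThread();   // contigPerThread[axis]
unsigned lanes = helper.getAxisNumThreadsPerWarpWithUniqueData();  // threadsPerWarp[axis]
unsigned warps = helper.getAxisNumWarpsWithUniqueData();           // warpsPerCTA[axis]
SmallVector<unsigned> order = helper.getOrder();                   // blocked order when available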
33 changes: 24 additions & 9 deletions lib/Analysis/Allocation.cpp
@@ -28,6 +28,8 @@ namespace triton {

// Bitwidth of pointers
constexpr int kPtrBitWidth = 64;
// Max width of a shmem LDS/STS instruction, in bits
constexpr int kMaxShmemVecBitLength = 128;

static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
RankedTensorType dstTy) {
@@ -79,15 +81,17 @@ std::pair<unsigned, unsigned>
getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
const auto &inOrd = gpu::getOrder(srcLayout);
const auto &outOrd = gpu::getOrder(dstLayout);

auto srcLinAttr = gpu::toLinearEncoding(srcLayout, srcTy.getShape());
auto dstLinAttr = gpu::toLinearEncoding(dstLayout, dstTy.getShape());
auto inOrd = srcLinAttr.getOrder();
auto outOrd = dstLinAttr.getOrder();

unsigned rank = srcTy.getRank();

unsigned srcContigPerThread =
gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
unsigned dstContigPerThread =
gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
// that we cannot do vectorization.
unsigned innerDim = rank - 1;
unsigned inVec = outOrd[0] != innerDim ? 1
@@ -117,8 +121,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
Attribute dstLayout = dstTy.getEncoding();

assert(cvtNeedsSharedMemory(srcTy, dstTy));

const auto &outOrd = gpu::getOrder(dstLayout);
auto outOrd = gpu::toLinearEncoding(dstLayout, dstTy.getShape()).getOrder();
scratchConfig.order = outOrd;

std::tie(scratchConfig.inVec, scratchConfig.outVec) =
@@ -129,6 +132,18 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
// is the max vectorisation
auto inBitWidth = isa<PointerType>(srcTy.getElementType())
? kPtrBitWidth
: srcTy.getElementTypeBitWidth();
auto outBitWidth = isa<PointerType>(dstTy.getElementType())
? kPtrBitWidth
: dstTy.getElementTypeBitWidth();
scratchConfig.inVec =
std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
scratchConfig.outVec =
std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);

// No padding is required if the tensor is 1-D, or if all dimensions except
// the first accessed dimension have a size of 1.
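A worked sketch of the new clamp (hypothetical values, not from the PR): with fp16 elements, the 128-bit LDS/STS limit caps the conversion's shared-memory vector length at 8 elements.

// Sketch with hypothetical values; kMaxShmemVecBitLength = 128 as defined above.
unsigned inBitWidth = 16;                            // fp16 source elements
unsigned inVec = 16;                                 // vector length before the clamp
inVec = std::min(inVec, 128u / inBitWidth);          // 128 / 16 = 8 -> inVec == 8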
22 changes: 13 additions & 9 deletions lib/Analysis/AxisInfo.cpp
@@ -1222,15 +1222,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
// FIXME: This is not as good as it could be, as we don't need to restrict
// the analysis to one dimension. We should determine contiguity on the
// flattenOuts() layout
auto linAttr =
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
auto order = linAttr.getOrder();
unsigned align = getPtrAlignment(ptr);

auto uniqueContigPerThread =
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
auto uniqueContigPerThread = linAttr.getContigPerThread();
assert(order[0] < uniqueContigPerThread.size() &&
"Unexpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];
@@ -1247,8 +1248,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto *axisInfo = getAxisInfo(ptr);
if (!axisInfo)
return 1;
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto linAttr =
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
auto order = linAttr.getOrder();
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1275,7 +1277,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto linAttr =
gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
auto maskOrder = linAttr.getOrder();
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
<< alignment);
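The AxisInfo changes above all follow the same pattern; a minimal sketch (assuming `tensorTy` is a RankedTensorType with a TritonGPU encoding; not code from the PR):

auto linAttr = triton::gpu::toLinearEncoding(tensorTy.getEncoding(),
                                             tensorTy.getShape());
auto order = linAttr.getOrder();             // order[0] is the fastest-varying dimension
auto contig = linAttr.getContigPerThread();  // contiguous elements per thread, per dim
unsigned fastDimContig = contig[order[0]];   // contiguity along the fastest dimension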
94 changes: 33 additions & 61 deletions lib/Analysis/Utility.cpp
@@ -22,37 +22,18 @@
#include "triton/Tools/Sys/GetEnv.hpp"

namespace mlir {
namespace {

using namespace triton;
using namespace triton::gpu;

int getParentAxis(Attribute layout, int axis) {
if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
return getParentAxis(sliceEncoding.getParent(), axis);
}
return axis;
}

SmallVector<unsigned> getParentOrder(Attribute layout) {
if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
return getParentOrder(sliceEncoding.getParent());
}
return getThreadOrder(layout);
}

} // namespace

// TODO(jlebar): Move this class into namespace triton.
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentOrder(getSrcLayout())[0];
auto linearEncoding = toLinearEncoding(getSrcLayout(), getSrcShape());
return linearEncoding.getOrder()[0] == axis;
}

SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
auto srcLayout = getSrcLayout();
auto order = getOrder(srcLayout);
auto order = toLinearEncoding(getSrcLayout(), getSrcShape()).getOrder();
auto it = std::find(order.begin(), order.end(), axis);
// delete the axis from order
order.erase(it);
@@ -219,69 +200,59 @@ bool ReduceOpHelper::isSupportedLayout() {
}

unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
return getEncoding().getSizePerThread()[getAxis()];
return getEncoding().getContigPerThread()[getAxis()];
}

unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
sizePerThreads[getAxis()] = 1;
return product<unsigned>(sizePerThreads);
auto contigPerThread = getEncoding().getContigPerThread();
contigPerThread[getAxis()] = 1;
return product<unsigned>(contigPerThread);
}

Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }

unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
return getThreadsPerWarp(getEncoding())[getAxis()];
}

unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()];
return getEncoding().getThreadsPerWarp()[getAxis()];
}

unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
threadsPerWarp[getAxis()] = 1;
return product<unsigned>(threadsPerWarp);
auto nThreads = product(getEncoding().getThreadsPerWarp());
return nThreads / getAxisNumThreadsPerWarpWithUniqueData();
}

// Return the flat numbers of threads computing independent scan results.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
warpsPerCTA[getAxis()] = 1;
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
}

unsigned ScanLoweringHelper::getAxisNumWarps() {
return getWarpsPerCTA(getEncoding())[getAxis()];
auto nWarps = product(getEncoding().getWarpsPerCTA());
return (nWarps / getAxisNumWarpsWithUniqueData()) *
getNonAxisNumThreadsPerWarp();
}

unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()];
return getEncoding().getWarpsPerCTA()[getAxis()];
}

unsigned ScanLoweringHelper::getAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto contigPerThread = getEncoding().getContigPerThread();
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
unsigned axis = getAxis();
return ceil<unsigned>(
getShape()[axis],
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
(contigPerThread[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}

unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto contigPerThread = getEncoding().getContigPerThread();
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
auto rank = contigPerThread.size();
unsigned axis = getAxis();
unsigned numBlocks = 1;
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
for (unsigned i = 0; i < rank; i++) {
if (i == axis)
continue;
numBlocks *=
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
ceil<unsigned>(getShape()[i], (contigPerThread[i] * threadsPerWarp[i] *
warpsPerCTA[i]));
}
return numBlocks;
@@ -290,7 +261,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
bool ScanLoweringHelper::isSupported() {
// TODO: Support the following cases:
// 1. Scan on non-blocking encodings
if (!isa<BlockedEncodingAttr>(srcEncoding))
if (!isa<BlockedEncodingAttr>(legacyEncoding))
return false;
return true;
}
@@ -579,42 +550,43 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
return ret;
}

BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
return cast<BlockedEncodingAttr>(srcEncoding);
}

unsigned ScanLoweringHelper::getAxisElementStride() {
auto order = getOrder(getEncoding());
auto order = getOrder();
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getContigPerThread(getEncoding())[dim];
stride *= getEncoding().getContigPerThread()[dim];
}
llvm_unreachable("Axis not found in order");
}

unsigned ScanLoweringHelper::getAxisThreadStride() {
auto order = getOrder(getEncoding());
auto encoding = getEncoding();
auto kThread = StringAttr::get(encoding.getContext(), "lane");
// OOOGHHH This is nasty. We should implement this lowering via LLs natively
// to avoid this
auto threadsPerWarp = encoding.basesPerDim(kThread, /*skipBroadcast=*/false);
auto order = getOrder();
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getEncoding().getThreadsPerWarp()[dim];
stride *= threadsPerWarp[dim];
}
llvm_unreachable("Axis not found in order");
}

unsigned ScanLoweringHelper::getAxisBlockStride() {
auto order = getOrder(getEncoding());
auto order = getOrder();
unsigned stride = 1;
auto sizePerThreads = getSizePerThread(getEncoding());
auto contigPerThread = getEncoding().getContigPerThread();
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
stride *= ceil<unsigned int>(getShape()[dim], contigPerThread[dim] *
threadsPerWarp[dim] *
warpsPerCTA[dim]);
}
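A worked example of the per-axis tiling arithmetic used by getAxisNumBlocks() and the *Stride helpers above (hypothetical layout values, not taken from the PR):

// Along the scan axis: contigPerThread = 4, threadsPerWarp = 16, warpsPerCTA = 2.
unsigned shapeAxis = 256;
unsigned elemsPerIteration = 4 * 16 * 2;                  // 128 elements covered per pass
unsigned axisNumBlocks = (shapeAxis + elemsPerIteration - 1)
                         / elemsPerIteration;             // ceil(256 / 128) == 2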
@@ -9,7 +9,6 @@ using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;