[GPU] Use affine.delinearize_index for MMA tiles and vector distribution (#19228)

This commit updates some by-hand delinearizations in MMA tile generation
and vector distribution to use `affine.delinearize_index` instead.

The main tricky thing here is that a lot of that MMA code would use `(id
/ stride) % size`, whereas delinearize's outputs all have the form `(id
% stride) / nextStride`. In all the cases at issue, we could use a
utility to convert arrays of sizes and strides to a permutation on a
delinearization basis.
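
For illustration only (not code from this change), with a made-up thread ID %tid and sizes 8 and 16, the two forms line up as follows; the delinearization returns its results outermost-first, which is where the sizes/strides-to-permutation utility comes in when the original arrays order the dimensions differently:

func.func @coords(%tid: index) -> (index, index) {
  // Old by-hand form (illustrative): (%tid floordiv 16) mod 8 and %tid mod 16.
  // One delinearization yields the same coordinates, outermost result first,
  // for any %tid in [0, 8 * 16).
  %coords:2 = affine.delinearize_index %tid into (8, 16) : index, index
  return %coords#0, %coords#1 : index, index
}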

In order not to break existing tests, the trivial-loop detector had to be
manually instrumented to support `delinearize_index` (and `util.assume.int`
while I was there). (I suspect there are a few other cases, and that,
long-term, that detector should be using one of the bounds interfaces, but
that's not this PR.)
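
For context, a hedged sketch (made-up shapes, not taken from the tests) of the kind of loop that detector has to see through: once a delinearized coordinate is known to be bounded by its basis entry, a distribution loop that starts at that coordinate and steps by the same bound runs exactly once and can be folded away.

func.func @single_trip(%tid: index) {
  %c64 = arith.constant 64 : index
  // %ids#2 lies in [0, 64) for any in-range %tid.
  %ids:3 = affine.delinearize_index %tid into (2, 4, 64) : index, index, index
  // Lower bound < 64 and step == 64, so the loop body runs exactly once.
  scf.for %i = %ids#2 to %c64 step %c64 {
    "test.use"(%i) : (index) -> ()
  }
  return
}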

---------

Co-authored-by: Quinn Dawkins <[email protected]>
krzysz00 and qedawkins authored Jan 28, 2025
1 parent aa9f8c5 commit ecd67d9
Showing 20 changed files with 418 additions and 342 deletions.
@@ -32,42 +32,6 @@ namespace mlir::iree_compiler {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

- /// Helper to linearize the given |ids| with maximum values given as |sizes|.
- /// Gets the element ID in terms of |elementCount| and adds the element
- /// |offset|. For example,
- ///
- /// IDs = [d0, d1, d2, d3]
- /// sizes = [s0, s1, s2, s3]
- /// linear_index = d0 * (s1 * s2 * s3)
- /// + d1 * (s2 * s3)
- /// + d2 * (s3)
- /// + d3
- /// return element_index = linear_index * |elementCount| + |offset|;
- static Value linearizeIndex(OpBuilder &builder, Value offset,
- ArrayRef<OpFoldResult> ids, ArrayRef<int64_t> sizes,
- int64_t elementCount) {
- SmallVector<AffineExpr> exprs(ids.size() + 1);
- bindSymbolsList(builder.getContext(), MutableArrayRef{exprs});
- AffineExpr idExpr = builder.getAffineConstantExpr(0);
-
- for (int i = 0, e = ids.size(); i < e; ++i) {
- if (sizes[i] > 1) {
- // Multiply by the residual threads along this dimension (which must be
- // faster changing than all previous dimensions) and add the id for this
- // dimension.
- idExpr = idExpr * builder.getAffineConstantExpr(sizes[i]) + exprs[i];
- }
- }
- idExpr = idExpr * builder.getAffineConstantExpr(elementCount);
- idExpr = idExpr + exprs.back();
- SmallVector<OpFoldResult> mapArgs(ids);
- mapArgs.push_back(offset);
- return affine::makeComposedAffineApply(
- builder, offset.getLoc(),
- AffineMap::get(0, mapArgs.size(), idExpr), mapArgs)
- .getResult();
- }

/// Given a set of base transfer |indices|, |offsets| for the batch/outer
/// dimensions, and distributed warp and thread indices, computes the indices
/// of the distributed transfer operation based on the |vectorLayout|.
@@ -94,16 +58,28 @@ static SmallVector<Value> getTransferIndicesFromNestedLayout(
continue;
}
unsigned pos = cast<AffineDimExpr>(dim).getPosition();
- SmallVector<OpFoldResult> ids = {
- warpIndices[i], b.getIndexAttr(batchOffsets[i]),
- b.getIndexAttr(outerVectorOffsets[i]), threadIndices[i]};
+ Value offset = indices[pos];
+ int64_t elementCount = vectorLayout.getElementTile()[i];
+ Location loc = offset.getLoc();
+ SmallVector<Value> ids = {
+ warpIndices[i], b.create<arith::ConstantIndexOp>(loc, batchOffsets[i]),
+ b.create<arith::ConstantIndexOp>(loc, outerVectorOffsets[i]),
+ threadIndices[i], offset};
// The order in which a vector dimension is "tiled" is
// subgroups -> batches -> outer vectors -> threads -> elements
SmallVector<int64_t> sizes = {
vectorLayout.getSubgroupTile()[i], vectorLayout.getBatchTile()[i],
- vectorLayout.getOuterTile()[i], vectorLayout.getThreadTile()[i]};
- slicedIndices[pos] = linearizeIndex(b, indices[pos], ids, sizes,
- vectorLayout.getElementTile()[i]);
+ vectorLayout.getOuterTile()[i], vectorLayout.getThreadTile()[i],
+ elementCount};
+ // The offset is often not an offset within `elementCount`, so, in general,
+ // we can't mark this `disjoint`. However, if `offset` is known to be
+ // a constant less than `elementCount`, we can do this, unlocking
+ // potential optimizations.
+ bool disjoint = false;
+ if (std::optional<int64_t> offsetConst = getConstantIntValue(offset))
+ disjoint = *offsetConst < elementCount;
+ slicedIndices[pos] =
+ b.create<affine::AffineLinearizeIndexOp>(loc, ids, sizes, disjoint);
}
return slicedIndices;
}
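
(Illustration, not part of the diff.) A rough sketch of the IR the rewritten helper now emits for one vector dimension, with made-up tile sizes (subgroup 1, batch 2, outer 1, thread 16, element 4) and hypothetical batch/outer offsets of 1/0; `disjoint` would only be set if the base offset were a constant known to be smaller than the element tile:

func.func @transfer_index(%warp: index, %thread: index, %base: index) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // [warp, batch offset, outer offset, thread, base offset] linearized by the
  // [subgroup, batch, outer, thread, element] tile sizes.
  %idx = affine.linearize_index [%warp, %c1, %c0, %thread, %base] by (1, 2, 1, 16, 4) : index
  return %idx : index
}
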
@@ -123,19 +99,21 @@ getElementVectorTileShape(NestedLayoutAttr vectorLayout) {

/// Computes the warp and thread indices for the given vector layout from a
/// single linearized thread ID.
- static void populateWarpAndThreadIndices(RewriterBase &rewriter, Value threadId,
- int64_t subgroupSize,
- NestedLayoutAttr vectorLayout,
- SmallVector<Value> &warpIndices,
- SmallVector<Value> &threadIndices) {
+ static LogicalResult populateWarpAndThreadIndices(
+ RewriterBase &rewriter, Value threadId, int64_t subgroupSize,
+ NestedLayoutAttr vectorLayout, SmallVector<Value> &warpIndices,
+ SmallVector<Value> &threadIndices) {
// The delinearized thread IDs are returned from outer most to inner most,
// i.e. before applying the layout described dimensions ordering.
int64_t rank = vectorLayout.getRank();
SmallVector<Value> threadIds =
vectorLayout.computeThreadIds(threadId, subgroupSize, rewriter);
+ if (threadIds.empty() && rank != 0)
+ return failure();
warpIndices = SmallVector<Value>(threadIds.begin(), threadIds.begin() + rank);
threadIndices = SmallVector<Value>(threadIds.begin() + rank,
threadIds.begin() + 2 * rank);
+ return success();
}

namespace {
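
(Illustration, not part of the diff.) Roughly, computeThreadIds is expected to delinearize the flat thread ID so that the first rank results feed warpIndices and the next rank results feed threadIndices; the sizes below are hypothetical, and the real basis comes from the layout's subgroup and thread strides:

// Hypothetical rank-2 layout: subgroup tile (2, 2), thread tile (4, 16), subgroup size 64.
func.func @warp_and_thread_ids(%thread_id: index) -> (index, index, index, index) {
  %ids:4 = affine.delinearize_index %thread_id into (2, 2, 4, 16) : index, index, index, index
  // warpIndices = [%ids#0, %ids#1], threadIndices = [%ids#2, %ids#3].
  return %ids#0, %ids#1, %ids#2, %ids#3 : index, index, index, index
}
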
@@ -189,8 +167,12 @@ struct DistributeTransferRead final
VectorValue acc = cast<VectorValue>(zero);

SmallVector<Value> warpIndices, threadIndices;
- populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, vectorLayout,
- warpIndices, threadIndices);
+ if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+ vectorLayout, warpIndices,
+ threadIndices))) {
+ return rewriter.notifyMatchFailure(
+ readOp, "warp or thread tiles have overlapping strides");
+ }

ValueRange indices = readOp.getIndices();
SmallVector<int64_t> strides(rank, 1);
Expand Down Expand Up @@ -259,8 +241,12 @@ struct DistributeTransferWrite final
int64_t rank = vectorLayout.getRank();

SmallVector<Value> warpIndices, threadIndices;
- populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, vectorLayout,
- warpIndices, threadIndices);
+ if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+ vectorLayout, warpIndices,
+ threadIndices))) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "warp or thread tiles have overlapping strides");
+ }

Value distributedVector =
getDistributed(rewriter, writeOp.getVector(), vectorLayout);
@@ -1282,8 +1268,12 @@ struct DistributeStep final : OpDistributionPattern<vector::StepOp> {
stepOp, "missing nested layout for step op result");
}
SmallVector<Value> subgroupIndices, threadIndices;
- populateWarpAndThreadIndices(rewriter, threadId, subgroupSize, resultLayout,
- subgroupIndices, threadIndices);
+ if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+ resultLayout, subgroupIndices,
+ threadIndices))) {
+ return rewriter.notifyMatchFailure(
+ stepOp, "warp or thread tiles have overlapping strides");
+ }

SmallVector<int64_t> undistributedShape =
resultLayout.getUndistributedPackedShape();
