[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops (#5673)

We generalise `HoistLayoutConversion` to lift a given `convert_layout` to
`#dot_operand` above any chain of operations that do not require data movement.
We could generalise this further in the future to lift it over other ops; for
now we keep it as a first step, so the code stays somewhat similar to the
previous implementation.
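
As a rough sketch of the rewrite (hand-written for illustration, not taken from this PR's tests; the layout aliases, shapes, and `kWidth` below are made-up assumptions):

```mlir
// Hypothetical layout aliases (definitions elided):
//   #blocked = #ttg.blocked<...>
//   #dot_op  = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>

// Before: the elementwise op runs in #blocked, and the convert to the
// dot-operand layout sits just before the dot, costing a shmem round-trip.
%a = tt.load %ptr : tensor<128x64x!tt.ptr<f16>, #blocked>
%b = arith.extf %a : tensor<128x64xf16, #blocked> to tensor<128x64xf32, #blocked>
%c = ttg.convert_layout %b : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #dot_op>

// After: the convert is hoisted above the chain, right next to the load, and
// the elementwise op runs directly in the dot-operand layout.
%a2 = ttg.convert_layout %a : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #dot_op>
%b2 = arith.extf %a2 : tensor<128x64xf16, #dot_op> to tensor<128x64xf32, #dot_op>
```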

Regarding the previous limitations of `canHoistDotOpEncV2`, I did a bit of
archeology:
- The "don't hoist past select" restriction was added in issue
#2857. I ran the repro and,
with the recent layout fixes, it now passes.
- The skipping of TruncOps comes from
#2181. I think this was
related to the hack that was removed in
#5044, so it should work now.
- The same goes for `UIToFPOp`: it is now supported after #5044.
- The mixed-dtype hack is not necessary either, as everything now works as
expected after the `convert_layout` rework.

We also add proper `isPure` support for `elementwise_inline_asm` ops.
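
For reference, a minimal hand-written sketch of such an op (the asm string and attribute values are illustrative assumptions, not from this PR); the new `ConditionallySpeculatable` implementation reports the op as speculatable exactly when `pure = true`:

```mlir
// Hypothetical example: a pure elementwise inline-asm op. Since pure = true,
// getSpeculatability() returns Speculatable, so patterns like the hoist above
// may move it; with pure = false it is treated as NotSpeculatable.
%y = tt.elementwise_inline_asm "shl.b32 $0, $1, 3;"
     {constraints = "=r,r", packed_element = 1 : i32, pure = true}
     %x : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked>
```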

Regarding the location of the code, we leave it in
`RemoveLayoutConversion.cpp` to take advantage of the rather generic
implementation of `rewriteSlice`. We could move this out of
`remove-layout-conversion`, as it's probably enough to run it once. This code
will go through further changes in the near future, so we'll reassess then.
lezcano authored Jan 31, 2025
1 parent 4ce54b5 commit b3dcc32
Showing 8 changed files with 438 additions and 421 deletions.
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -830,7 +830,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
Elementwise,
SameOperandsAndResultEncoding,
-DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+DeclareOpInterfaceMethods<ConditionallySpeculatable>
]> {
let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
let description = [{
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {

def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
TransposeOpInterface,
-DeclareOpInterfaceMethods<InferTypeOpInterface>,
+InferTypeOpWithLayoutEquivalence,
SameOperandsAndResultElementType]> {
let summary = "transpose the descriptor";

6 changes: 6 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -1037,6 +1037,12 @@ void ElementwiseInlineAsmOp::getEffects(
SideEffects::DefaultResource::get());
}

+Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() {
+  if (getPure())
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}

LogicalResult ElementwiseInlineAsmOp::verify() {
if (getNumOperands() >= 1) {
auto tensorType = dyn_cast<RankedTensorType>(getOperand(0).getType());
28 changes: 15 additions & 13 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -463,15 +463,17 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
return {};
}

-LogicalResult MemDescTransOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
+LogicalResult
+MemDescTransOp::inferReturnTypes(MLIRContext *context,
+                                 std::optional<Location> location,
+                                 MemDescTransOp::Adaptor adaptor,
+                                 SmallVectorImpl<Type> &inferredReturnTypes) {

  // type is the same as the input
-  auto argTy = cast<MemDescType>(operands[0].getType());
-  auto argShape = argTy.getShape();
-  auto order = properties.as<Properties *>()->order.asArrayRef();
-  SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
+  auto argTy = cast<MemDescType>(adaptor.getSrc().getType());
+  auto shape = argTy.getShape();
+  auto order = adaptor.getOrder();
+  SmallVector<int64_t> retShape = applyPermutation(shape, order);

auto retEltTy = argTy.getElementType();
Attribute argEncoding = argTy.getEncoding();
@@ -480,17 +482,17 @@ LogicalResult MemDescTransOp::inferReturnTypes(
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
-          ->inferTransOpEncoding(argEncoding, argShape, order, retEncoding)
+          ->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
.failed()) {
return failure();
}
}
-  auto memDescTy = cast<MemDescType>(argTy);
-  inferredReturnTypes.push_back(MemDescType::get(
-      retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
-      memDescTy.getMutableMemory()));
+  inferredReturnTypes.push_back(
+      MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(),
+                       argTy.getMutableMemory()));
return success();
}

// LocalAllocOp
void LocalAllocOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
142 changes: 0 additions & 142 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -24,36 +24,6 @@ namespace {
// Roughly, whether op is elementwise and thus threads don't need
// to exchange elements. But some ops are not currently supported even though
// they meet that criterion.
-bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
-  // Only consider custom conversions or arith ops.
-  // TODO(jlebar): Is this too restrictive?
-  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
-      !isa<arith::ArithDialect>(op->getDialect()))
-    return false;
-
-  // Quick handling to fix loading issues when computing the original
-  // bitwidth is unable to realize that there is a mixed-precision dot
-  // (hence kWidth = 1) but wants to hoist through the type conversion.
-  if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
-    return false;
-
-  // Currently, these instructions are not supported during lowering of
-  // shared -> dot_operand layout. Not all types and type conversions are
-  // supported.
-  if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
-    return false;
-
-  // Don't hoist through u1 -> fp casts as they aren't supported in
-  // ElementwiseOpToLLVM::reorderValues().
-  if (isa<arith::UIToFPOp>(op)) {
-    Type opType = getElementTypeOrSelf(op->getOperand(0));
-    if (opType.isInteger(1))
-      return false;
-  }
-
-  return true;
-}

// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation *op) {
@@ -195,116 +165,6 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
}
};

-// Move convert-to-dot-operand "up" past elementwise ops:
-//
-//   convert(elementwise(x)) #dot_operand ->
-//   elementwise(convert(x, #dot_operand)).
-//
-// The goal is to put the convert right next to the originating load. If we can
-// accomplish this, then we can save a shmem round-trip:
-//
-// Before:
-//
-//  - Load from global into shmem using an async copy.
-//  - Load from shmem into a #blocked layout.
-//  - Do elementwise ops over #blocked layout.
-//  - Convert to #dot_operand (round-trip through shmem).
-//  - Do dot.
-//
-// After:
-//
-//  - Load from global into shmem using an async copy (same as before).
-//  - Load from shmem into a #dot_operand layout.
-//  - Do elementwise ops over #dot_operand layout.
-//  - Do dot.
-//
-// This can also be propagated when we have a constant, instead of a load.
-//
-// Eliminating the shmem round-trip is such a big win, we're willing to do it
-// even if this duplicates work because some of the elementwise ops have uses
-// that don't flow into the dot. On the other hand, we only want to do this if
-// we can in fact reduce shmem round-trips: For example, simply moving a convert
-// up above e.g. an `add` now means we have *two* converts. That's worse,
-// unless we can continue moving the converts upwards and eventually merge them.
-// So we try to check that this will be beneficial before making any changes.
-class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
-                                PatternRewriter &rewriter) const override {
-    // Only consider conversions to dot operand.
-    auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
-    if (!dotOpEnc)
-      return failure();
-
-    auto src = cvt.getSrc().getDefiningOp();
-    if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
-      return failure();
-
-    auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
-    if (!srcTy)
-      return failure();
-
-    if (!all_of(src->getOperandTypes(),
-                [](Type ty) { return isa<RankedTensorType>(ty); }))
-      return failure();
-
-    if (!canHoistDotOpEncV2(src, dotOpEnc))
-      return failure();
-
-    // Check that the conversion is transitively dependent on a load or a
-    // constant, and all operations between it and the convert are layout
-    // preserving.
-    //
-    // TODO(jlebar): This is accidentally quadratic; we iterate over the whole
-    // slice but then at the end we only modify one op!
-    SetVector<Operation *> slice;
-    BackwardSliceOptions opt;
-    opt.omitBlockArguments = true;
-    getBackwardSlice(cvt.getOperation(), &slice, opt);
-
-    // TODO(jlebar): This is too conservative when there are multiple loads in
-    // the chain. If one of the loads has a non-layout-preserving op and the
-    // other does not, then we may or may not accept the chain, depending on
-    // which load gets hit first by getBackwardSlice. For example:
-    //   cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
-    //   cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
-    bool foundInitializer = false;
-    // Reverse the slice so that we start directly above the convert and check
-    // that every op allows hoisting until we find a load or a constant.
-    for (Operation *currOp : llvm::reverse(slice)) {
-      if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
-        foundInitializer = true;
-        break;
-      }
-      if (!canHoistDotOpEncV2(currOp, dotOpEnc))
-        return failure();
-    }
-    if (!foundInitializer)
-      return failure();
-
-    SmallVector<ConvertLayoutOp> newOperands;
-    for (auto operand : src->getOperands()) {
-      // We checked earlier that all operands are ranked tensors.
-      auto operandTy = cast<RankedTensorType>(operand.getType());
-      Type newCvtTy = RankedTensorType::get(
-          srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
-      newOperands.push_back(
-          rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
-    }
-    auto newRet = rewriter.clone(*src);
-    for (int i = 0; i < newOperands.size(); i++)
-      newRet->setOperand(i, newOperands[i]);
-    newRet->getResult(0).setType(RankedTensorType::get(
-        srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));
-
-    rewriter.replaceOp(cvt, newRet->getResults());
-    return success();
-  }
-};

// Rewrite
//
// dot(alloc(trans() #shared1) ->
@@ -699,8 +559,6 @@ class TritonGPUOptimizeDotOperandsPass
mlir::RewritePatternSet patterns(context);
patterns.add<MMAV3HoistLayoutConversion>(context);
patterns.add<SwizzleShmemConvert>(context);
-    if (this->hoistLayoutConversion.getValue())
-      patterns.add<HoistLayoutConversion>(context);
patterns.add<FuseTransMMAV3Plus>(context);
patterns.add<MMAV3UseRegOperand>(context);
patterns.add<InjectTMemCopy>(context);