Sync getConvertBackwardSlice from upstream (#3329)
Changes come from upstream commits 24b8d43 and a6b15ef.

Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang authored Feb 2, 2025
1 parent b9ba137 commit 94efbc1
Showing 3 changed files with 100 additions and 39 deletions.
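
The gist of the sync: getConvertBackwardSlice now takes its root as an OpOperand & (the use, not just the value) and accepts an optional getExistingConversion callback that lets the slicing reuse a previously materialized conversion. Below is a minimal, hypothetical caller sketch for the updated interface; it is not part of this commit. The wrapping function name, the triton::gpu::ConvertLayoutOp variable, the empty callback, and the omitted includes/using-declarations are assumptions for illustration; only ttgi::getConvertBackwardSlice and ConvertLayoutOp::getSrcMutable() appear in the diff itself.

// Illustrative sketch only (not part of this commit): slice backward from the
// source operand of a layout conversion using the updated interface.
void sliceConvertSourceExample(triton::gpu::ConvertLayoutOp convertOp) {
  SetVector<Value> slice;
  DenseMap<Value, Attribute> layout;

  // The root is now passed as a use (OpOperand &), not a bare Value.
  OpOperand &root = convertOp.getSrcMutable();
  Attribute rootEncoding = convertOp.getType().getEncoding();

  // Optional hook: return an already-materialized value in `encoding` that can
  // stand in for `use`, or a null Value when nothing reusable is known.
  auto getExistingConversion = [](OpOperand &use, Attribute encoding) -> Value {
    return Value(); // assumption: no reuse attempted in this sketch
  };

  LogicalResult result = ttgi::getConvertBackwardSlice(
      root, slice, rootEncoding, layout,
      /*stopPropagation=*/nullptr, getExistingConversion);
  if (failed(result) || slice.empty())
    return; // encoding conflict, unsupported op, or nothing to rematerialize
}
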
@@ -45,9 +45,11 @@ getDotEncoding(RankedTensorType tensorType);
// Get backward slice of tensor values starting from the root node along with
// encoding propagation.
LogicalResult getConvertBackwardSlice(
Value root, SetVector<Value> &slice, Attribute rootEncoding,
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);
std::function<bool(Operation *)> stopPropagation = nullptr,
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
nullptr);

LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
ArrayRef<Type> paramTypes,
@@ -154,17 +154,18 @@ class LayoutPropagation {
class LayoutRematerialization {
public:
LayoutRematerialization(FuncOp F) : funcOp(F) {}

// Map the original value to the remat'ed one.
void addRematValue(Value old, Attribute encoding, Value newV);
// Get the remat'ed value in the given encoding, if one already exists and
// is different then the layout conversion root.
Value getRematValue(Value value, Attribute encoding) const {
return rematMapping.lookup({value, encoding});
}

bool hasRematValue(Value value, Attribute encoding) {
return rematMapping.contains({value, encoding});
}
// Return the remat'ed value in the given encoding.
Value getRematValue(Value value, Attribute encoding) {
auto it = rematMapping.find({value, encoding});
assert(it != rematMapping.end());
return it->second;
}
void cleanup();
void backwardRematerialization();
void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -175,6 +176,11 @@ class LayoutRematerialization {
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp);

LogicalResult getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
// Existing tuples of (value, layout) that needs to be updated when recreating
@@ -186,6 +192,7 @@ class LayoutRematerialization {
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
SetVector<Operation *> opToDelete;
FuncOp funcOp;
DominanceInfo domInfo;
};

void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -1188,10 +1195,33 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
rewriteSlice(slice, layout, convertOp, mapping);
}

LogicalResult getRematerializableSlice(
Value root, Attribute rootEncoding, SetVector<Value> &slice,
LogicalResult LayoutRematerialization::getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr) {
std::function<bool(Operation *)> stopPropagation) {
// Allow re-using existing conversions for a value. Check dominance of any
// reusable materializations against the root value. This is sufficient
// because the conversions are processed in post-order.
auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
Value remat = getRematValue(value.get(), encoding);
if (!remat)
return Value();
// `value` can be replaced with an existing rematerialization if it
// dominates the current use of value.
Operation *user = value.getOwner();
if (domInfo.properlyDominates(remat, user)) {
return remat;
}
// Alternatively, if the current use can be sunk below the existing
// rematerialization, then it is okay to use as well. E.g. the current use
// is a conversion that will be folded away when its result is
// rematerialized.
if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp() &&
domInfo.properlyDominates(user, remat.getDefiningOp())) {
return remat;
}
return Value();
};
LogicalResult result = ttgi::getConvertBackwardSlice(
root, slice, rootEncoding, layout, std::move(stopPropagation));
if (result.failed() || slice.empty())
@@ -1255,7 +1285,7 @@ void LayoutRematerialization::backwardRematerialization(
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getSrc(), targetType.getEncoding(), slice, layout);
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
if (result.failed()) {
LDBG(" getRematerializableSlice failed");
return;
@@ -1287,9 +1317,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
// 1. Take a backward slice of all the tensor dependencies.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result =
getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(),
slice, layout, isExtOrBroadcastOp);
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
isExtOrBroadcastOp);
if (result.failed())
return;

@@ -1307,7 +1337,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
if (!srcEncoding)
return;
LogicalResult result = getRematerializableSlice(
op->getOperand(0), srcEncoding, tempSlice, tempLayout);
op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
// If we can rematerialize the rest of the ext slice we can ignore this
// ext as it won't need a convert.
if (result.succeeded()) {
75 changes: 52 additions & 23 deletions third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -149,41 +149,60 @@ static bool isFreeConvert(Operation *op) {
convertOp.getType());
}

LogicalResult
getConvertBackwardSlice(Value root, SetVector<Value> &slice,
Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation) {
DenseSet<std::pair<Value, Attribute>> seen;
SmallVector<std::pair<Value, Attribute>> queue;

auto enqueue = [&](Value operand, Attribute encoding) {
auto x = std::make_pair(operand, encoding);
LogicalResult getConvertBackwardSlice(
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation,
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
DenseSet<std::pair<OpOperand *, Attribute>> seen;
SmallVector<std::pair<OpOperand *, Attribute>> queue;

auto enqueue = [&](OpOperand &operand, Attribute encoding) {
auto x = std::make_pair(&operand, encoding);
if (!seen.insert(x).second) {
return; // Already enqueued, skip
}
queue.push_back(x);
};
enqueue(root, rootEncoding);

auto updateLayout = [&](Value value, Attribute encoding) {
assert(isTensorOrTensorPointerType(value.getType()));
slice.insert(value);
if (layout.find(value) != layout.end()) {
if (layout[value] != encoding)
return failure();
}
layout[value] = encoding;
return success();
};

while (!queue.empty()) {
auto [currentValue, encoding] = queue.back();
auto [currentValueUse, encoding] = queue.back();
Value currentValue = currentValueUse->get();
queue.pop_back();
if (!isTensorOrTensorPointerType(currentValue.getType()))
continue;
slice.insert(currentValue);
if (layout.find(currentValue) != layout.end()) {
if (layout[currentValue] != encoding)
// Skip propagating through for op results for now.
// TODO: enable this based on needs.
if (currentValue.getDefiningOp<scf::ForOp>())
return failure();
if (failed(updateLayout(currentValue, encoding)))
return failure();

Value existing;
if (getExistingConversion &&
(existing = getExistingConversion(*currentValueUse, encoding))) {
if (failed(updateLayout(existing, encoding)))
return failure();
currentValue = existing;
}
layout[currentValue] = encoding;

if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
auto results = ifOp.getResults();
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();

auto thenValue = ifOp.thenYield().getOperand(argIdx);
auto elseValue = ifOp.elseYield().getOperand(argIdx);
OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx);

enqueue(thenValue, encoding);
enqueue(elseValue, encoding);
@@ -196,10 +215,11 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
if (result == currentValue ||
!isTensorOrTensorPointerType(result.getType()))
continue;
enqueue(result, encoding);
if (failed(updateLayout(result, encoding)))
return failure();
}
if (isFreeConvert(definingOp)) {
enqueue(definingOp->getOperand(0), encoding);
enqueue(definingOp->getOpOperand(0), encoding);
continue;
}
if (canFoldIntoConversion(definingOp, encoding))
@@ -208,7 +228,16 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
continue;
if (isa<triton::CatOp>(definingOp))
return failure();
for (Value operand : definingOp->getOperands()) {
if (auto gather = dyn_cast<GatherOp>(definingOp)) {
// Specially handle gather since its transfer function only applies
// between its index operand and result.
auto srcEncoding = ttgi::inferSrcEncoding(gather, encoding);
if (!srcEncoding)
return failure();
enqueue(gather.getIndicesMutable(), srcEncoding);
continue;
}
for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) {
auto srcEncoding = ttgi::inferSrcEncoding(definingOp, encoding);
if (!srcEncoding)
return failure();
@@ -221,9 +250,9 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
Operation *parentOp = block->getParentOp();
if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand(
blockArg.getArgNumber() - forOp.getNumInductionVars());
enqueue(initOperand->get(), encoding);
enqueue(*initOperand, encoding);
enqueue(yieldOperand, encoding);
continue;
}
