From 2f91d11901781df8127c4edcc1b07c5a17271bad Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:40:29 -0800 Subject: [PATCH] [NFC] Clarify comments in BubbleUpExpand shapes pass. (#19837) The comments were written in terms of "fusion", but this pass is not fusing, it is moving reshapes. Make that clear from comments. Signed-off-by: MaheshRavishankar --- .../DispatchCreation/BubbleUpExpandShapes.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index d282bd9f9417..482e255f162f 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -129,24 +129,28 @@ void BubbleUpExpandShapesPass::runOnOperation() { return false; } - // Do not fuse by expand if consumer is dequant. + // Do not push down collapse shape across consumer if it is a bit-extend + // op. The bit-extend ops get cloned into producer dispatches, and the + // `collapse_shape` op going past dequant, prevents this clong. if (IREE::LinalgExt::isBitExtendOp(consumer)) { return false; } - // Do not fuse producer generic op if it has more than one user - // or any reduction iterators. + // If producer generic op is elementwise op, bubble up the expand shape + // past this operation. if (auto producerGenericOp = dyn_cast(producer)) { return llvm::all_of(producerGenericOp.getIteratorTypesArray(), linalg::isParallelIterator); } - // Do not fuse with any producer linalg named ops for now. + // Do not bubble up expand shapes across named ops for now. if (isa(producer)) { return false; } - // Do not fuse with consumer linalg named ops or reductions. + // Do not push expand shapes down across operations with reduction + // iterator types. + // TODO: This condition should be removed. if (auto consumerLinalgOp = dyn_cast(consumer)) { return isa(consumerLinalgOp) && llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),