diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index d282bd9f9417..482e255f162f 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -129,24 +129,28 @@ void BubbleUpExpandShapesPass::runOnOperation() { return false; } - // Do not fuse by expand if consumer is dequant. + // Do not push down collapse shape across consumer if it is a bit-extend + // op. The bit-extend ops get cloned into producer dispatches, and the + // `collapse_shape` op going past dequant, prevents this clong. if (IREE::LinalgExt::isBitExtendOp(consumer)) { return false; } - // Do not fuse producer generic op if it has more than one user - // or any reduction iterators. + // If producer generic op is elementwise op, bubble up the expand shape + // past this operation. if (auto producerGenericOp = dyn_cast(producer)) { return llvm::all_of(producerGenericOp.getIteratorTypesArray(), linalg::isParallelIterator); } - // Do not fuse with any producer linalg named ops for now. + // Do not bubble up expand shapes across named ops for now. if (isa(producer)) { return false; } - // Do not fuse with consumer linalg named ops or reductions. + // Do not push expand shapes down across operations with reduction + // iterator types. + // TODO: This condition should be removed. if (auto consumerLinalgOp = dyn_cast(consumer)) { return isa(consumerLinalgOp) && llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),