Skip to content

Commit

Permalink
Reapply "Propagate reshapes through generics with reduction… (#18968)
Browse files Browse the repository at this point in the history
Reland after fixing sdxl int8 regressions via
#19012.

Running CI revealed further performance regressions that have pending
patches: #19325 and
#19326.

This reverts commit 8d3faf8.

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Jan 8, 2025
1 parent 80cbf6b commit a5c3879
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ jobs:
--goldentime-rocm-vae-ms 310.0 \
--goldendispatch-rocm-unet 1602 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 246 \
--goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand All @@ -150,7 +150,7 @@ jobs:
--goldentime-rocm-vae-ms 75.0 \
--goldendispatch-rocm-unet 1602 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 246 \
--goldendispatch-rocm-vae 245 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
--goldensize-rocm-vae-bytes 840000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,24 +134,18 @@ void BubbleUpExpandShapesPass::runOnOperation() {
return false;
}

// Do not fuse producer generic op if it has more than one user
// or any reduction iterators.
if (auto producerGenericOp = dyn_cast<linalg::GenericOp>(producer)) {
return producerGenericOp->hasOneUse() &&
llvm::all_of(producerGenericOp.getIteratorTypesArray(),
linalg::isParallelIterator);
return true;
}

// Do not fuse with any producer linalg named ops for now.
if (isa<linalg::LinalgOp>(producer)) {
return false;
}

// Do not fuse with consumer linalg named ops or reductions.
// Do not fuse with consumer linalg named ops.
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
return isa<linalg::GenericOp>(consumerLinalgOp) &&
llvm::all_of(consumerLinalgOp.getIteratorTypesArray(),
linalg::isParallelIterator);
return isa<linalg::GenericOp>(consumerLinalgOp);
}
// Fuse in all other cases.
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ struct BubbleUpExtractSlicesPass
patterns.insert<BubbleUpExtract>(context);
patterns.insert<SwapExtractSliceOfFill>(context);
tensor::populateFoldTensorEmptyPatterns(patterns, false);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,23 @@ func.func @bubble_up_extract_slice_single_use(%arg0: tensor<131072xi64>, %arg1:
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] :
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK: return %[[GENERIC]]

// -----

func.func @fold_extract_of_expand_of_fill(%arg0 : index, %arg1 : index, %arg2 : index) -> tensor<?xf16> {
%cst0 = arith.constant 0.0 : f16
%0 = tensor.empty(%arg0) : tensor<?xf16>
%2 = linalg.fill ins(%cst0 : f16) outs(%0 : tensor<?xf16>) -> tensor<?xf16>
%3 = tensor.expand_shape %2 [[0, 1]] output_shape[1, %arg1] : tensor<?xf16> into tensor<1x?xf16>
%4 = tensor.extract_slice %3 [0, 0] [1, %arg2] [1, 1] : tensor<1x?xf16> to tensor<?xf16>
func.return %4 : tensor<?xf16>
}

// CHECK-LABEL: func @fold_extract_of_expand_of_fill
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[ARG2]])
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0.0
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST0]] : f16) outs(%[[EMPTY]]
// CHECK: return %[[FILL]]

0 comments on commit a5c3879

Please sign in to comment.