From 4b0ca34a768377d648f1efc7f377a852d7a943a9 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 27 Jan 2025 16:08:05 -0800 Subject: [PATCH] Support fusing broadcast transposes with attention (#19828) Attention V operand transposition can sometimes lead to a broadcast + transpose op that won't get folded into the attention op (if it also fails to propagate up during transpose propagation). This change folds these ops into the attention op. --------- Signed-off-by: Ian Wood --- .../LinalgExt/Transforms/TransposeFusion.cpp | 43 +++++++------ .../test/elementwise_op_fusion.mlir | 63 +++++++++++++++++++ 2 files changed, 86 insertions(+), 20 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp index bcc94ec951c0..fc3c8fbf01ff 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp @@ -9,6 +9,7 @@ #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h" #include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" @@ -58,42 +59,44 @@ struct FuseTransposeWithAttentionOp final LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp, PatternRewriter &rewriter) const override { - OpOperand *transposeOperand = nullptr; - linalg::LinalgOp transposeOp; + OpOperand *operand = nullptr; + linalg::LinalgOp producer; for (OpOperand *input : attentionOp.getDpsInputOperands()) { if (controlFn && !controlFn(input)) { continue; } - auto maybeTransposeOp = input->get().getDefiningOp(); - if (maybeTransposeOp && isaTranspose(maybeTransposeOp) && - maybeTransposeOp->hasOneUse()) { - transposeOp = maybeTransposeOp; - transposeOperand = input; + auto maybeProducer = input->get().getDefiningOp(); + if (maybeProducer && maybeProducer.isSingleYieldOp()) { + producer = maybeProducer; + operand = input; break; } } - if (!transposeOperand) { - return rewriter.notifyMatchFailure(attentionOp, "no transpose operand"); + if (!operand) { + return rewriter.notifyMatchFailure(attentionOp, "no operand found"); } - int64_t inputIndex = transposeOperand->getOperandNumber(); - SmallVector perm = getPermutation(transposeOp); - auto invPerm = invertPermutationVector(perm); + int64_t inputIndex = operand->getOperandNumber(); + + auto producerMaps = producer.getIndexingMapsArray(); + AffineMap producerInputMap = producerMaps[0]; + AffineMap producerResultMap = producerMaps[1]; + if (!producerInputMap.isProjectedPermutation() || + !producerResultMap.isPermutation()) { + return failure(); + } rewriter.modifyOpInPlace(attentionOp, [&]() { SmallVector newIndexingMaps = attentionOp.getIndexingMapsArray(); - AffineMap inputMap = attentionOp.getMatchingIndexingMap(transposeOperand); - SmallVector newExprs = - applyPermutation(inputMap.getResults(), invPerm); - AffineMap transposedMap = - AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), - newExprs, rewriter.getContext()); - newIndexingMaps[inputIndex] = transposedMap; + AffineMap consumerInputMap = attentionOp.getMatchingIndexingMap(operand); + AffineMap composedMap = + producerInputMap.compose(inversePermutation(producerResultMap)); + newIndexingMaps[inputIndex] = 
composedMap.compose(consumerInputMap); attentionOp.setIndexingMapsAttr( rewriter.getAffineMapArrayAttr(newIndexingMaps)); - attentionOp.setOperand(inputIndex, transposeOp.getDpsInputs()[0]); + attentionOp.setOperand(inputIndex, producer.getDpsInputs()[0]); }); return success(); diff --git a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir index 096c882ab219..52c5b1e41dc4 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir @@ -208,6 +208,8 @@ util.func public @fuse_generic_gather2( // CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32 // CHECK-NEXT: linalg.yield %[[RES4]] : f32 +// ----- + util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> { // Dequantize int-quantization of V %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16> @@ -258,3 +260,64 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()> // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]] + +// ----- + +util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> { + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x8x128x?xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x8x4x128x?xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) { + ^bb0(%arg7: f32): + iree_linalg_ext.yield %arg7 : f32 + } -> tensor<4x8x4x?x32x128xf16> + util.return %1 : tensor<4x8x4x?x32x128xf16> +} + +// CHECK-LABEL: func public @fuse_attention_with_broadcast +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, 
d1, d5, d7)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] : +// CHECK: util.return %[[ATTENTION]] + + +// ----- + +util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x128xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> { + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4, d1)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x8x128xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x8x4x128x?xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) { + ^bb0(%arg7: f32): + iree_linalg_ext.yield %arg7 : f32 + } -> tensor<4x8x4x?x32x128xf16> + util.return %1 : tensor<4x8x4x?x32x128xf16> +} + +// CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d5)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] : +// CHECK: util.return %[[ATTENTION]]
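
Note (not part of the patch): the core of the new pattern is the AffineMap composition that rewrites the attention op's indexing map for the fused operand, producerInputMap.compose(inversePermutation(producerResultMap)).compose(consumerInputMap). The standalone C++ sketch below reproduces that composition for the broadcast case exercised by @fuse_attention_with_broadcast; it assumes an MLIR build to compile and link against and only illustrates the map algebra, it is not code from the patch.

// Illustrative sketch (not part of the patch): reproduces the indexing-map
// composition used by FuseTransposeWithAttentionOp for the broadcast case in
// @fuse_attention_with_broadcast. Assumes MLIR headers/libraries are available.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  // Producer generic (pure broadcast along d2):
  //   input map  (d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)
  //   result map (d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)
  AffineMap producerInputMap =
      AffineMap::get(5, 0,
                     {getAffineDimExpr(0, &ctx), getAffineDimExpr(1, &ctx),
                      getAffineDimExpr(3, &ctx), getAffineDimExpr(4, &ctx)},
                     &ctx);
  AffineMap producerResultMap = AffineMap::getMultiDimIdentityMap(5, &ctx);
  // Attention's indexing map for the V operand:
  //   (d0, ..., d7) -> (d0, d1, d2, d5, d7)
  AffineMap consumerInputMap =
      AffineMap::get(8, 0,
                     {getAffineDimExpr(0, &ctx), getAffineDimExpr(1, &ctx),
                      getAffineDimExpr(2, &ctx), getAffineDimExpr(5, &ctx),
                      getAffineDimExpr(7, &ctx)},
                     &ctx);
  // Same composition as the pattern: undo the producer's output permutation,
  // then apply its input projection to the consumer's indexing map.
  AffineMap composedMap =
      producerInputMap.compose(inversePermutation(producerResultMap));
  AffineMap fusedOperandMap = composedMap.compose(consumerInputMap);
  // Prints (d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7), matching the
  // third affine_map in the CHECK lines of @fuse_attention_with_broadcast.
  fusedOperandMap.print(llvm::outs());
  llvm::outs() << "\n";
  return 0;
}

The composition only works when producerResultMap is a full permutation (so it can be inverted) and producerInputMap is a projected permutation, which is exactly what the pattern checks before rewriting; that single condition is what lets a broadcast, a transpose, or a fused broadcast + transpose all fold into the attention op through the same code path.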