diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
index fc3c8fbf01ff..bcc94ec951c0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TransposeFusion.cpp
@@ -9,7 +9,6 @@
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -59,44 +58,42 @@ struct FuseTransposeWithAttentionOp final
 
   LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = nullptr;
-    linalg::LinalgOp producer;
+    OpOperand *transposeOperand = nullptr;
+    linalg::LinalgOp transposeOp;
     for (OpOperand *input : attentionOp.getDpsInputOperands()) {
       if (controlFn && !controlFn(input)) {
         continue;
       }
 
-      auto maybeProducer = input->get().getDefiningOp<linalg::LinalgOp>();
-      if (maybeProducer && maybeProducer.isSingleYieldOp()) {
-        producer = maybeProducer;
-        operand = input;
+      auto maybeTransposeOp = input->get().getDefiningOp<linalg::LinalgOp>();
+      if (maybeTransposeOp && isaTranspose(maybeTransposeOp) &&
+          maybeTransposeOp->hasOneUse()) {
+        transposeOp = maybeTransposeOp;
+        transposeOperand = input;
         break;
       }
     }
-    if (!operand) {
-      return rewriter.notifyMatchFailure(attentionOp, "no operand found");
+    if (!transposeOperand) {
+      return rewriter.notifyMatchFailure(attentionOp, "no transpose operand");
     }
 
-    int64_t inputIndex = operand->getOperandNumber();
-
-    auto producerMaps = producer.getIndexingMapsArray();
-    AffineMap producerInputMap = producerMaps[0];
-    AffineMap producerResultMap = producerMaps[1];
-    if (!producerInputMap.isProjectedPermutation() ||
-        !producerResultMap.isPermutation()) {
-      return failure();
-    }
+    int64_t inputIndex = transposeOperand->getOperandNumber();
+    SmallVector<int64_t> perm = getPermutation(transposeOp);
+    auto invPerm = invertPermutationVector(perm);
 
     rewriter.modifyOpInPlace(attentionOp, [&]() {
      SmallVector<AffineMap> newIndexingMaps =
          attentionOp.getIndexingMapsArray();
-      AffineMap consumerInputMap = attentionOp.getMatchingIndexingMap(operand);
-      AffineMap composedMap =
-          producerInputMap.compose(inversePermutation(producerResultMap));
-      newIndexingMaps[inputIndex] = composedMap.compose(consumerInputMap);
+      AffineMap inputMap = attentionOp.getMatchingIndexingMap(transposeOperand);
+      SmallVector<AffineExpr> newExprs =
+          applyPermutation(inputMap.getResults(), invPerm);
+      AffineMap transposedMap =
+          AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(),
+                         newExprs, rewriter.getContext());
+      newIndexingMaps[inputIndex] = transposedMap;
 
       attentionOp.setIndexingMapsAttr(
           rewriter.getAffineMapArrayAttr(newIndexingMaps));
-      attentionOp.setOperand(inputIndex, producer.getDpsInputs()[0]);
+      attentionOp.setOperand(inputIndex, transposeOp.getDpsInputs()[0]);
     });
 
     return success();
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
index 52c5b1e41dc4..096c882ab219 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir
@@ -208,8 +208,6 @@ util.func public @fuse_generic_gather2(
 // CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
 // CHECK-NEXT: linalg.yield %[[RES4]] : f32
 
-// -----
-
 util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf16>, %k: tensor<2x10x4096x64xf16>, %quantized_v: tensor<2x10x4096x64xi32>, %quant_offset: tensor<10x64xi32>, %quant_scale: tensor<10x64xf32>, %scale: f16) -> tensor<2x10x4096x64xf16> {
   // Dequantize int-quantization of V
   %init_dequant = tensor.empty() : tensor<2x10x4096x64xf16>
@@ -260,64 +258,3 @@ util.func public @fuse_transpose_attention_to_producer(%q: tensor<2x10x4096x64xf
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[DEQUANT_V]], %[[ARG5]]
-
-// -----
-
-util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> {
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x8x128x?xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    linalg.yield %in : f16
-  } -> tensor<4x8x4x128x?xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) {
-  ^bb0(%arg7: f32):
-    iree_linalg_ext.yield %arg7 : f32
-  } -> tensor<4x8x4x?x32x128xf16>
-  util.return %1 : tensor<4x8x4x?x32x128xf16>
-}
-
-// CHECK-LABEL: func public @fuse_attention_with_broadcast
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
-// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
-// CHECK: util.return %[[ATTENTION]]
-
-
-// -----
-
-util.func public @fuse_attention_with_broadcast_transpose(%arg0: tensor<4x?x8x128xf16>, %arg1: tensor<4x8x4x?x32x128xf16>, %arg2: tensor<4x8x4x?x128xf16>, %arg3: f16, %arg4: tensor<4x8x4x?x32x?xf16>, %arg5: tensor<4x8x4x?x32x128xf16>, %arg6: tensor<4x8x4x128x?xf16>) -> tensor<4x8x4x?x32x128xf16> {
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4, d1)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x8x128xf16>) outs(%arg6 : tensor<4x8x4x128x?xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    linalg.yield %in : f16
-  } -> tensor<4x8x4x128x?xf16>
-  %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%arg1, %arg2, %0, %arg3, %arg4 : tensor<4x8x4x?x32x128xf16>, tensor<4x8x4x?x128xf16>, tensor<4x8x4x128x?xf16>, f16, tensor<4x8x4x?x32x?xf16>) outs(%arg5 : tensor<4x8x4x?x32x128xf16>) {
-  ^bb0(%arg7: f32):
-    iree_linalg_ext.yield %arg7 : f32
-  } -> tensor<4x8x4x?x32x128xf16>
-  util.return %1 : tensor<4x8x4x?x32x128xf16>
-}
-
-// CHECK-LABEL: func public @fuse_attention_with_broadcast_transpose
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]:
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]:
-// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d6)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d5)>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>,
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d7)>
-// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
-// CHECK: util.return %[[ATTENTION]]
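Note (not part of the patch): the restored pattern folds a transpose producer into `iree_linalg_ext.attention` by permuting the results of the consumer's indexing map with the inverse of the transpose permutation, so the fused op reads the untransposed operand directly. The sketch below illustrates only that map rewrite using the upstream MLIR utilities the patch calls (`invertPermutationVector`, `applyPermutation`, `AffineMap::get`); the example map and permutation are made up for illustration and the IREE-specific helpers (`isaTranspose`, `getPermutation`) are not exercised.

```cpp
// Standalone sketch of the indexing-map rewrite, assuming only upstream MLIR.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;

  // Hypothetical consumer map for the transposed operand:
  // (d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d7).
  mlir::AffineMap inputMap = mlir::AffineMap::getMultiDimMapWithTargets(
      /*numDims=*/8, /*targets=*/{0, 1, 2, 5, 7}, &ctx);

  // Hypothetical producer transpose permutation and its inverse
  // (this one swaps the last two dims, so it is its own inverse).
  llvm::SmallVector<int64_t> perm = {0, 1, 2, 4, 3};
  llvm::SmallVector<int64_t> invPerm = mlir::invertPermutationVector(perm);

  // Permute the consumer's result exprs by the inverse permutation so the
  // fused op indexes the untransposed operand directly.
  llvm::SmallVector<mlir::AffineExpr> newExprs =
      mlir::applyPermutation(inputMap.getResults(), invPerm);
  mlir::AffineMap transposedMap = mlir::AffineMap::get(
      inputMap.getNumDims(), inputMap.getNumSymbols(), newExprs, &ctx);

  // Expected output: (d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d7, d5)
  transposedMap.print(llvm::outs());
  llvm::outs() << "\n";
  return 0;
}
```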