From 50a70876310b0a41b0e40a94ef7618565dbec243 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com>
Date: Wed, 29 Jan 2025 17:18:37 -0600
Subject: [PATCH] Add pattern to convert generic conv ops to IGEMM (#19798)

This PR removes the named-op patterns that convert convs to IGEMM and
replaces them with a generic pattern that works for all supported convs.
A new utility function is added that populates the shared details
required for setting the lowering config and doing the IGEMM
computation.

The PR currently uses a default-true flag,
`iree-gpu-use-tile-and-fuse-generic-convolution`. The idea is that since
a lot more convolutions will go down the IGEMM path with this PR, if any
of them run into issues we can turn the flag off by default rather than
needing to revert the whole PR. If after some time we find that there
are no issues, we can drop the flag and make this the unconditional
behavior.

---------

Signed-off-by: Nirvedh Meshram
---
 .../Dialect/GPU/TargetUtils/ConfigUtils.cpp   |  33 +-
 .../Transforms/ConvertConv2DToIm2ColOp.cpp    | 359 ++++++++----------
 .../Transforms/test/conv2d_to_im2col.mlir     |  58 +++
 .../Dialect/LinalgExt/Utils/BUILD.bazel       |   1 +
 .../Dialect/LinalgExt/Utils/CMakeLists.txt    |   1 +
 .../Dialect/LinalgExt/Utils/Utils.cpp         | 222 +++++++++++
 .../compiler/Dialect/LinalgExt/Utils/Utils.h  |  32 ++
 7 files changed, 493 insertions(+), 213 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 81bb71c3052c..d90e18b90d36 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -31,6 +31,14 @@ namespace mlir::iree_compiler::IREE::GPU {
 
+// TODO (nirvedhmeshram): This flag allows a lot more convolutions to use
+// IGEMM; drop it after sufficient use with no issues.
+llvm::cl::opt<bool> clGPUUseTileAndFuseGenericConvolution(
+    "iree-gpu-use-tile-and-fuse-generic-convolution",
+    llvm::cl::desc(
+        "enable the tile and fuse pipeline for generic convolutions"),
+    llvm::cl::init(true));
+
 constexpr int64_t kCacheLineSizeBits = 128 * 8;
 constexpr int64_t kPreferredCopyNumBits = 128;
 
@@ -371,12 +379,25 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
     return failure();
   LDBG("IGEMM TileAndFuse Config");
-  FailureOr<SmallVector<AffineMap>> igemmContractionMaps =
-      LinalgExt::getIGEMMContractionIndexingMaps(linalgOp);
-  FailureOr<SmallVector<int64_t>> igemmLoopBounds =
-      LinalgExt::getIGEMMLoopBounds(linalgOp);
-  FailureOr<SmallVector<Value>> igemmOperands =
-      LinalgExt::getIGEMMOperands(linalgOp);
+  FailureOr<SmallVector<AffineMap>> igemmContractionMaps;
+  FailureOr<SmallVector<int64_t>> igemmLoopBounds;
+  FailureOr<SmallVector<Value>> igemmOperands;
+  if (!clGPUUseTileAndFuseGenericConvolution) {
+    igemmContractionMaps =
+        LinalgExt::getIGEMMContractionIndexingMaps(linalgOp);
+    igemmLoopBounds = LinalgExt::getIGEMMLoopBounds(linalgOp);
+    igemmOperands = LinalgExt::getIGEMMOperands(linalgOp);
+  } else {
+    FailureOr<LinalgExt::IGEMMGenericConvDetails> igemmGenericConvDetails =
+        LinalgExt::getIGEMMGenericConvDetails(linalgOp);
+    if (failed(igemmGenericConvDetails)) {
+      LDBG("Unsupported generic convolution type");
+      return failure();
+    }
+    igemmContractionMaps = igemmGenericConvDetails->igemmContractionMaps;
+    igemmLoopBounds = igemmGenericConvDetails->igemmLoopBounds;
+    igemmOperands = igemmGenericConvDetails->igemmOperands;
+  }
+
   if (failed(igemmContractionMaps) || failed(igemmLoopBounds) ||
       failed(igemmOperands)) {
     LDBG("Unsupported convolution type");
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index 905252b31b3f..e37d3f9488f4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -17,9 +17,8 @@ namespace mlir::iree_compiler::IREE::LinalgExt {
 
 #define GEN_PASS_DEF_CONVERTCONV2DTOIM2COLOPPASS
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
 
-static bool hasAllOneValues(DenseIntElementsAttr attr) {
-  return llvm::all_of(
-      attr, [](APInt element) { return element.getSExtValue() == 1; });
+static bool hasAllOneValues(ArrayRef<int64_t> attr) {
+  return llvm::all_of(attr, [](int64_t element) { return element == 1; });
 }
 
 static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
@@ -36,12 +35,37 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
   return builder.create<arith::MulFOp>(loc, x, y);
 }
 
+// TODO: The upstream utility that does this pruning is broken for LinalgOp.
+// Drop this if that gets fixed.
+static SmallVector<NamedAttribute>
+getPrunedAttributeList(linalg::LinalgOp op) {
+  const StringLiteral memoAttr =
+      linalg::LinalgDialect::kMemoizedIndexingMapsAttrName;
+  SmallVector<NamedAttribute> prunedAttributeList;
+  for (auto attr : op->getDiscardableAttrs()) {
+    if (attr.getName() != memoAttr) {
+      prunedAttributeList.push_back(attr);
+    }
+  }
+  return prunedAttributeList;
+}
+
+// Helper to convert a shape into a basis for the im2col op.
+static SmallVector<int64_t> getBasisFromShape(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> basis(shape.size());
+  int64_t cumulativeProduct = 1;
+  for (int i = shape.size() - 1; i >= 0; --i) {
+    basis[i] = cumulativeProduct;
+    cumulativeProduct *= shape[i];
+  }
+  return basis;
+}
+
 namespace {
 
 using ControlFnTy = std::function<bool(Operation *)>;
-
-// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
+// Converts non-depthwise convs into linalg.generic (for img2col packing)
 // and linalg.matmul.
+// The following explains this for a linalg.conv_2d_nhwc_hwcf op.
 //
 // A convolution operation can be written as a matrix-matrix multiplication by
 // unfolding the cross correlation between input and filter and explicitly
 // copying the data.
@@ -73,219 +97,141 @@ using ControlFnTy = std::function<bool(Operation *)>;
 // multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in
 // the N inputs. For the case where N > 1 it is a batched matrix-matrix
 // multiplication.
-class ConvertConv2DNhwcHwcf final
-    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  ConvertConv2DNhwcHwcf(MLIRContext *context,
-                        std::optional<ControlFnTy> controlFn)
-      : OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context),
-        controlFn(controlFn) {}
+class ConvertConvGeneric final
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+public:
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+  ConvertConvGeneric(MLIRContext *context,
+                     std::optional<ControlFnTy> controlFn)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+        controlFn(controlFn) {}
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    if (controlFn.has_value() && !controlFn.value()(convOp)) {
-      return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
+    if (controlFn.has_value() && !controlFn.value()(linalgOp)) {
+      return rewriter.notifyMatchFailure(linalgOp, "controlFn failed.");
     }
-    auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
-    auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
-    auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
-
-    if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) {
-        diag << "[unimplemented] "
-             << "expected 'filterType' and 'inputType' to have static shapes.";
-      });
+    auto igemmConvDetailsOrFailure =
+        LinalgExt::getIGEMMGenericConvDetails(linalgOp);
+    if (failed(igemmConvDetailsOrFailure))
+      return rewriter.notifyMatchFailure(linalgOp,
+                                         "Failed to extract IGEMM details");
+
+    LinalgExt::IGEMMGenericConvDetails igemmConvDetails =
+        *igemmConvDetailsOrFailure;
+
+    SmallVector<AffineMap> igemmContractionMaps =
+        igemmConvDetails.igemmContractionMaps;
+    mlir::linalg::ConvolutionDimensions convDims = igemmConvDetails.convDims;
+    SmallVector<ReassociationIndices> filterReassocIndices =
+        igemmConvDetails.filterReassocIndices;
+    bool isOutputChannelFirst = igemmConvDetails.isOutputChannelFirst;
+    SmallVector<int64_t> igemmLoopBounds = igemmConvDetails.igemmLoopBounds;
+    SmallVector<utils::IteratorType> igemmLoopIterators =
+        igemmConvDetails.igemmLoopIterators;
+
+    Value input = linalgOp.getDpsInputs()[0];
+    Value filter = linalgOp.getDpsInputs()[1];
+    Value output = linalgOp.getDpsInits()[0];
+    auto inputType = llvm::cast<ShapedType>(input.getType());
+    auto filterType = llvm::cast<ShapedType>(filter.getType());
+    auto outputType = llvm::cast<ShapedType>(output.getType());
+
+    ArrayRef<int64_t> filterShape = filterType.getShape();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    SmallVector<AffineMap> indexingMaps =
+        linalgOp.getIndexingMapsArray();
+    AffineMap inputMap = indexingMaps[0];
+    AffineMap filterMap = indexingMaps[1];
+    AffineMap outputMap = indexingMaps[2];
+
+    SmallVector<OpFoldResult> kernelSizes;
+    for (auto filterLoop : convDims.filterLoop) {
+      std::optional<int64_t> maybeDim = filterMap.getResultPosition(
+          getAffineDimExpr(filterLoop, filterMap.getContext()));
+      if (!maybeDim) {
+        return rewriter.notifyMatchFailure(linalgOp,
+                                           "Failed to infer filter shape.");
+      }
+      kernelSizes.push_back(
+          rewriter.getIndexAttr(filterShape[maybeDim.value()]));
     }
 
-    // TODO: Support dilation.
-    if (!hasAllOneValues(convOp.getDilations())) {
-      return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) {
-        diag << "[unimplemented] "
-             << "expected no dilations (expected dilations to all be one).";
-      });
+    // Shape of the resulting tensor from im2col.
+    SmallVector<int64_t> colTensorShape;
+    SmallVector<int64_t> batchPos;
+    for (auto batch : convDims.batch) {
+      std::optional<int64_t> maybeBatch = inputMap.getResultPosition(
+          getAffineDimExpr(batch, inputMap.getContext()));
+      if (!maybeBatch) {
+        return rewriter.notifyMatchFailure(linalgOp,
+                                           "Failed to infer batch shape.");
+      }
+      batchPos.push_back(maybeBatch.value());
+      colTensorShape.push_back(inputShape[maybeBatch.value()]);
     }
 
-    Value input = convOp.getInputs()[0];
-    Value filter = convOp.getInputs()[1];
-    Value output = convOp.getOutputs()[0];
-
-    auto filterShape = filterType.getShape();
-    auto outputShape = outputType.getShape();
-
-    const int n = outputShape[0];
-    const int oh = outputShape[1];
-    const int ow = outputShape[2];
-    const int oc = outputShape[3];
-    const int fh = filterShape[0];
-    const int fw = filterShape[1];
-    const int ic = filterShape[2];
-
-    auto loc = convOp.getLoc();
-
-    SmallVector<int64_t> colTensorShape = {n, oh, ow, fh * fw * ic};
-
-    SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
-
-    Value colTensor = rewriter.create<tensor::EmptyOp>(
-        loc, colTensorShape, inputType.getElementType());
-    SmallVector<int64_t> strides(convOp.getStrides().getValues<int64_t>());
-    SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
-    SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
-                                            rewriter.getIndexAttr(fw)};
-    OpFoldResult zero = rewriter.getIndexAttr(0);
-    OpFoldResult one = rewriter.getIndexAttr(1);
-    SmallVector<OpFoldResult> mOffset = {zero, zero};
-    SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(ow), one};
-    SmallVector<OpFoldResult> kOffset = {zero};
-    SmallVector<OpFoldResult> kBasis = {one};
-    SmallVector<int64_t> batchPos = {0};
-    SmallVector<int64_t> mPos = {1, 2};
-    SmallVector<int64_t> kPos = {3};
-    Value img2ColTensor = rewriter
-                              .create<Im2colOp>(
-                                  loc, input, /*output=*/colTensor, strides,
-                                  dilations, kernelSize, mOffset, mBasis,
-                                  kOffset, kBasis, batchPos, mPos, kPos)
-                              .getResult(0);
-
-    SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
-    auto reshapedFilterType =
-        RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
-
-    Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedFilterType, filter, filterReassocIndices);
-
-    SmallVector<AffineMap> indexingMaps =
-        getIGEMMContractionIndexingMaps(convOp).value();
-    auto parallel = utils::IteratorType::parallel;
-    auto reduction = utils::IteratorType::reduction;
-    SmallVector<utils::IteratorType> genericIterators = {
-        parallel, parallel, parallel, parallel, reduction};
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, outputType,
-        /*inputs=*/ValueRange{img2ColTensor, reshapedFilter},
-        /*outputs=*/ValueRange{output}, indexingMaps, genericIterators,
-        [](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          Value lhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[0],
-                                           args[2].getType(),
-                                           /*isUnsignedCast=*/false);
-          Value rhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[1],
-                                           args[2].getType(),
-                                           /*isUnsignedCast=*/false);
-          Value mul = createMul(nestedLoc, lhs, rhs, nestedBuilder);
-          Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
-          nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
-        },
-        linalg::getPrunedAttributeList(convOp));
-    Value result = genericOp.getResults().front();
-
-    rewriter.replaceOp(convOp, result);
-
-    return success();
-  }
-
-private:
-  std::optional<ControlFnTy> controlFn;
-};
-
-// For nchw, because the channels are to the left of the image shape
-// dimensions, the position of the contraction dimension in the resulting
-// matmul is reversed. This swaps the LHS and RHS of the matmul when compared
-// with nhwc (i.e. (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo))
-class ConvertConv2DNchwFchw final
-    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  ConvertConv2DNchwFchw(MLIRContext *context,
-                        std::optional<ControlFnTy> controlFn)
-      : OpRewritePattern<linalg::Conv2DNchwFchwOp>(context),
-        controlFn(controlFn) {}
-
-  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (controlFn.has_value() && !controlFn.value()(convOp)) {
-      return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
+    SmallVector<int64_t> mPos;
+    SmallVector<int64_t> mShape;
+    for (auto outputImage : convDims.outputImage) {
+      for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
+        if (e.isFunctionOfDim(outputImage)) {
+          mPos.push_back(idx);
+        }
+      }
+      for (auto [idx, e] : llvm::enumerate(outputMap.getResults())) {
+        if (e.isFunctionOfDim(outputImage)) {
+          mShape.push_back(outputShape[idx]);
+          colTensorShape.push_back(outputShape[idx]);
+        }
+      }
     }
-    auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
-    auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
-    auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
-
-    if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) {
-        diag << "[unimplemented] "
-             << "expected 'filterType' and 'inputType' to have static shapes.";
-      });
+    SmallVector<int64_t> kPos;
+    for (auto reductionDim : convDims.inputChannel) {
+      for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) {
+        if (e.isFunctionOfDim(reductionDim)) {
+          kPos.push_back(idx);
+        }
+      }
     }
-
-    // TODO: Support dilation.
-    if (!hasAllOneValues(convOp.getDilations()))
-      return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) {
-        diag << "[unimplemented] "
-             << "expected no dilations (expected dilations to all be one).";
-      });
-
-    Value input = convOp.getInputs()[0];
-    Value filter = convOp.getInputs()[1];
-    Value output = convOp.getOutputs()[0];
-
-    auto filterShape = filterType.getShape();
-    auto outputShape = outputType.getShape();
-
-    const int n = outputShape[0];
-    const int oc = outputShape[1];
-    const int oh = outputShape[2];
-    const int ow = outputShape[3];
-    const int ic = filterShape[1];
-    const int fh = filterShape[2];
-    const int fw = filterShape[3];
-
-    auto loc = convOp.getLoc();
-
-    SmallVector<int64_t> colTensorShape = {n, oh, ow, fh * fw * ic};
-
+    // The index at which the reduction dimension bounds start in
+    // igemmLoopBounds.
+    int64_t reductionBoundIndex = convDims.batch.size() +
+                                  convDims.outputImage.size() +
+                                  convDims.outputChannel.size();
+    SmallVector<int64_t> kShape(igemmLoopBounds.begin() + reductionBoundIndex,
+                                igemmLoopBounds.end());
+    colTensorShape.insert(colTensorShape.end(), kShape.begin(), kShape.end());
+
+    SmallVector<OpFoldResult> mBasis =
+        getAsIndexOpFoldResult(getContext(), getBasisFromShape(mShape));
+    SmallVector<OpFoldResult> kBasis =
+        getAsIndexOpFoldResult(getContext(), getBasisFromShape(kShape));
+
+    SmallVector<OpFoldResult> kOffset(kBasis.size(), rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> mOffset(mBasis.size(), rewriter.getIndexAttr(0));
+    auto loc = linalgOp.getLoc();
     Value colTensor = rewriter.create<tensor::EmptyOp>(
         loc, colTensorShape, inputType.getElementType());
-    SmallVector<int64_t> strides(convOp.getStrides().getValues<int64_t>());
-    SmallVector<int64_t> dilations(convOp.getDilations().getValues<int64_t>());
-    SmallVector<OpFoldResult> kernelSize = {rewriter.getIndexAttr(fh),
-                                            rewriter.getIndexAttr(fw)};
-    OpFoldResult zero = rewriter.getIndexAttr(0);
-    OpFoldResult one = rewriter.getIndexAttr(1);
-    SmallVector<OpFoldResult> mOffset = {zero, zero};
-    SmallVector<OpFoldResult> mBasis = {rewriter.getIndexAttr(ow), one};
-    SmallVector<OpFoldResult> kOffset = {zero};
-    SmallVector<OpFoldResult> kBasis = {one};
-    SmallVector<int64_t> batchPos = {0};
-    SmallVector<int64_t> mPos = {2, 3};
-    SmallVector<int64_t> kPos = {1};
-    Value img2ColTensor = rewriter
-                              .create<Im2colOp>(
-                                  loc, input, /*output=*/colTensor, strides,
-                                  dilations, kernelSize, mOffset, mBasis,
-                                  kOffset, kBasis, batchPos, mPos, kPos)
-                              .getResult(0);
+    Value img2ColTensor =
+        rewriter
+            .create<Im2colOp>(
+                loc, input, /*output=*/colTensor, convDims.strides,
+                convDims.dilations, kernelSizes, mOffset, mBasis, kOffset,
+                kBasis, batchPos, mPos, kPos)
+            .getResult(0);
 
-    SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
-    auto reshapedFilterType =
-        RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType());
     Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
-        loc, reshapedFilterType, filter, filterReassocIndices);
+        loc, filter, filterReassocIndices);
 
-    SmallVector<AffineMap> indexingMaps =
-        getIGEMMContractionIndexingMaps(convOp).value();
-    auto parallel = utils::IteratorType::parallel;
-    auto reduction = utils::IteratorType::reduction;
-    SmallVector<utils::IteratorType> genericIterators = {
-        parallel, parallel, parallel, parallel, reduction};
-    auto genericOp = rewriter.create<linalg::GenericOp>(
+    auto genericGEMMOp = rewriter.create<linalg::GenericOp>(
         loc, outputType,
-        /*inputs=*/ValueRange{reshapedFilter, img2ColTensor},
-        /*outputs=*/ValueRange{output}, indexingMaps, genericIterators,
+        /*inputs=*/
+        isOutputChannelFirst ? ValueRange{reshapedFilter, img2ColTensor}
+                             : ValueRange{img2ColTensor, reshapedFilter},
+        /*outputs=*/ValueRange{output}, igemmContractionMaps,
+        igemmLoopIterators,
         [](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
           Value lhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[0],
                                            args[2].getType(),
@@ -296,12 +242,11 @@ class ConvertConv2DNchwFchw final
           Value mul = createMul(nestedLoc, lhs, rhs, nestedBuilder);
           Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder);
           nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
-        },
-        linalg::getPrunedAttributeList(convOp));
-    Value result = genericOp.getResults().front();
-
-    rewriter.replaceOp(convOp, result);
+        });
+    genericGEMMOp->setDiscardableAttrs(getPrunedAttributeList(linalgOp));
 
+    Value result = genericGEMMOp.getResults().front();
+    rewriter.replaceOp(linalgOp, result);
     return success();
   }
 
@@ -327,8 +272,8 @@ struct ConvertConv2DToIm2ColOpPass final
 
 void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
                                       std::optional<ControlFnTy> controlFn) {
-  patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(
-      patterns.getContext(), std::move(controlFn));
+  patterns.insert<ConvertConvGeneric>(patterns.getContext(),
+                                      std::move(controlFn));
 }
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
index 7a83d465fd16..f993f61649a4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir
@@ -137,3 +137,61 @@ util.func public @conv_strided(%arg0: tensor<1x16x16x4xf16>, %arg1: tensor<3x3x4
 // CHECK:      arith.addf
 // CHECK:    } -> tensor<1x7x7x16xf32>
 // CHECK:    util.return %[[MATMUL]] : tensor<1x7x7x16xf32>
+
+// -----
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d3, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+util.func public @conv_nhwc_hwfc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x16x4xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %3 = arith.mulf %in, %in_0 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<1x14x14x16xf32>
+  util.return %0 : tensor<1x14x14x16xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d3, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK: util.func public @conv_nhwc_hwfc(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x16x4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x14x14x16xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x9x4xf32>
+// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-SAME: m_offset = [0, 0] * [14, 1] k_offset = [0, 0] * [4, 1]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x14x14x9x4xf32>) -> tensor<1x14x14x9x4xf32>
+// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]] : tensor<3x3x16x4xf32> into tensor<9x16x4xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x9x4xf32>, tensor<9x16x4xf32>)
+// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32>
+
+// -----
+util.func public @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  util.return %0 : tensor<1x14x14x16xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+// CHECK: util.func public @conv_2d_nhwc_fhwc(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<16x3x3x4xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x14x14x16xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x36xf32>
+// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
+// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x36xf32>, tensor<16x36xf32>)
+// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
index 9196a939e7ff..bf240b8d83f3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel
@@ -31,6 +31,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgUtils",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TensorDialect",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
index d9d69318c576..24eb1c5852c6 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt
@@ -27,6 +27,7 @@ iree_cc_library(
     MLIRArithDialect
     MLIRIR
     MLIRLinalgDialect
+    MLIRLinalgUtils
     MLIRMemRefDialect
     MLIRSupport
     MLIRTensorDialect
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index 1922bfb2a2f5..734d7c9c4555 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -8,16 +8,26 @@
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Builders.h"
 
+#define DEBUG_TYPE "iree-linalgExt-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
+static bool hasAllOneValues(ArrayRef<int64_t> attr) {
+  return llvm::all_of(attr, [](int64_t element) { return element == 1; });
+}
+
 OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a,
                      OpFoldResult b) {
   AffineExpr d0, d1;
@@ -452,4 +462,216 @@ FailureOr<SmallVector<Value>> getIGEMMOperands(linalg::LinalgOp linalgOp) {
       .Default([](Operation *) { return failure(); });
 }
 
+FailureOr<IGEMMGenericConvDetails>
+getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) {
+  auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp);
+  MLIRContext *ctx = linalgOp->getContext();
+  if (failed(convDimsOrFailure))
+    return failure();
+  const mlir::linalg::ConvolutionDimensions &convDims = *convDimsOrFailure;
+  LLVM_DEBUG({
+    llvm::dbgs() << "conv: " << linalgOp;
+    llvm::dbgs() << "\nconv batch dim: ";
+    llvm::interleaveComma(convDims.batch, llvm::dbgs());
+    llvm::dbgs() << "\nconv output window dims: ";
+    llvm::interleaveComma(convDims.outputImage, llvm::dbgs());
+    llvm::dbgs() << "\nconv output channel dim: ";
+    llvm::interleaveComma(convDims.outputChannel, llvm::dbgs());
+    llvm::dbgs() << "\nconv filter window dims: ";
+    llvm::interleaveComma(convDims.filterLoop, llvm::dbgs());
+    llvm::dbgs() << "\nconv input channel dims: ";
+    llvm::interleaveComma(convDims.inputChannel, llvm::dbgs());
+    llvm::dbgs() << "\nconv depth multiplier: ";
+    llvm::interleaveComma(convDims.depth, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+  Value input = linalgOp.getDpsInputs()[0];
+  Value filter = linalgOp.getDpsInputs()[1];
+  Value output = linalgOp.getDpsInits()[0];
+  auto inputType = llvm::cast<ShapedType>(input.getType());
+  auto filterType = llvm::cast<ShapedType>(filter.getType());
+  auto outputType = llvm::cast<ShapedType>(output.getType());
+
+  if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) {
+    LDBG("[unimplemented] expected 'filterType' and 'inputType' to have "
+         "static shapes.");
+    return failure();
+  }
+
+  // TODO: Support dilation.
+  if (!hasAllOneValues(convDims.dilations)) {
+    LDBG("[unimplemented] expected no dilations (expected dilations to all "
+         "be one).");
+    return failure();
+  }
+
+  // TODO: Support depthwise.
+  if (!convDims.depth.empty()) {
+    LDBG("[unimplemented] expected no depth");
+    return failure();
+  }
+
+  // TODO: Support pooling operations. For pooling ops, the input/output
+  // channel size will be categorized as an additional batch dimension.
+  if (convDims.outputChannel.empty() || convDims.inputChannel.empty()) {
+    LDBG("[unimplemented] expected no pooling operations");
+    return failure();
+  }
+
+  auto filterShape = filterType.getShape();
+  auto outputShape = outputType.getShape();
+  auto indexingMaps = linalgOp.getIndexingMapsArray();
+  auto filterMap = indexingMaps[1];
+
+  SmallVector<int64_t> reductionDims;
+  for (auto iter : llvm::enumerate(linalgOp.getIteratorTypesArray())) {
+    if (linalg::isReductionIterator(iter.value())) {
+      reductionDims.push_back(iter.index());
+    }
+  }
+
+  SmallVector<int64_t> filterkPos;
+  for (auto reductionDim : reductionDims) {
+    std::optional<int64_t> maybeDim = filterMap.getResultPosition(
+        getAffineDimExpr(reductionDim, filterMap.getContext()));
+    filterkPos.push_back(maybeDim.value());
+  }
+
+  // Group together adjacent reduction dimensions in the filter.
+  SmallVector<ReassociationIndices> collapsedFilterReductionDim;
+  int64_t prevFilterIndex = filterkPos[0];
+  int64_t currCollapsedIndex = 0;
+  collapsedFilterReductionDim.push_back({filterkPos[0]});
+  SmallVector<int64_t> kShape = {filterShape[filterkPos[0]]};
+  for (auto currPos : llvm::ArrayRef(filterkPos).drop_front()) {
+    if (prevFilterIndex == currPos - 1) {
+      collapsedFilterReductionDim[currCollapsedIndex].push_back(currPos);
+    } else {
+      collapsedFilterReductionDim.push_back({currPos});
+      ++currCollapsedIndex;
+    }
+    prevFilterIndex = currPos;
+  }
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> filterIterators;
+  SmallVector<int64_t> filterNdims;
+  for (auto outputChannel : convDims.outputChannel) {
+    std::optional<int64_t> maybeDim = filterMap.getResultPosition(
+        getAffineDimExpr(outputChannel, filterMap.getContext()));
+    filterNdims.push_back(maybeDim.value());
+  }
+
+  SmallVector<ReassociationIndices> filterReassocIndices;
+  // Interleave the parallel dims with the reduction dims.
+  int64_t filterNdimPos = 0;
+  for (auto collapsedDim : collapsedFilterReductionDim) {
+    for (int i = filterNdimPos; i < filterNdims.size(); i++) {
+      if (filterNdims[i] < collapsedDim[0]) {
+        filterReassocIndices.push_back({filterNdims[i]});
+        filterIterators.push_back(parallel);
+        filterNdimPos = i + 1;
+      } else {
+        break;
+      }
+    }
+    filterIterators.push_back(reduction);
+    filterReassocIndices.push_back(collapsedDim);
+  }
+  // Insert any leftover parallel dims at the end.
+  for (int i = filterNdimPos; i < filterNdims.size(); i++) {
+    filterReassocIndices.push_back({filterNdims[i]});
+    filterIterators.push_back(parallel);
+  }
+
+  SmallVector<int64_t> reshapedFilterShape(filterReassocIndices.size(), 1);
+  for (auto [idx, indices] : llvm::enumerate(filterReassocIndices)) {
+    for (auto index : indices) {
+      reshapedFilterShape[idx] *= filterShape[index];
+    }
+  }
+
+  int64_t numBDims = (convDims.batch).size();
+  int64_t numMDims = (convDims.outputImage).size();
+  int64_t numNDims = (convDims.outputChannel).size();
+  int64_t numParallelDims = numBDims + numMDims + numNDims;
+  int64_t numKDims = collapsedFilterReductionDim.size();
+  SmallVector<utils::IteratorType> genericIterators(numParallelDims,
+                                                    parallel);
+  genericIterators.insert(genericIterators.end(), numKDims, reduction);
+
+  SmallVector<AffineExpr> dims(numParallelDims + numKDims);
+  bindDimsList<AffineExpr>(ctx, dims);
+  auto resultMap = AffineMap::get(
+      numParallelDims + numKDims, 0,
+      SmallVector<AffineExpr>(dims.begin(), dims.begin() + numParallelDims),
+      ctx);
+
+  bool isOutputChannelFirst = false;
+  auto outputChannelPos = convDims.outputChannel;
+  auto outputImagePos = convDims.outputImage;
+  if (outputChannelPos.back() < outputImagePos[0])
+    isOutputChannelFirst = true;
+
+  // Prepare the input map.
+  SmallVector<AffineExpr> inputDims;
+  // Add the batch dimensions.
+  inputDims.insert(inputDims.end(), dims.begin(), dims.begin() + numBDims);
+  int64_t starting_m_pos =
+      isOutputChannelFirst ? numBDims + numNDims : numBDims;
+  // Add the M dims.
+  inputDims.insert(inputDims.end(), dims.begin() + starting_m_pos,
+                   dims.begin() + starting_m_pos + numMDims);
+  // Add the reduction dims.
+  inputDims.insert(inputDims.end(), dims.begin() + numParallelDims,
+                   dims.end());
+  auto inputMapGEMM =
+      AffineMap::get(numParallelDims + numKDims, 0, inputDims, ctx);
+
+  // Prepare the filter map.
+  SmallVector<AffineExpr> filterDims;
+  int64_t curr_n_pos = isOutputChannelFirst ? numBDims : numBDims + numMDims;
+  int64_t curr_k_pos = numBDims + numMDims + numNDims;
+
+  for (auto iter : filterIterators) {
+    if (iter == parallel) {
+      filterDims.push_back(dims[curr_n_pos++]);
+    } else if (iter == reduction) {
+      filterDims.push_back(dims[curr_k_pos++]);
+    }
+  }
+  auto filterMapGEMM =
+      AffineMap::get(numParallelDims + numKDims, 0, filterDims, ctx);
+
+  SmallVector<AffineMap> indexingGEMMMaps;
+  if (isOutputChannelFirst) {
+    indexingGEMMMaps.push_back(filterMapGEMM);
+    indexingGEMMMaps.push_back(inputMapGEMM);
+  } else {
+    indexingGEMMMaps.push_back(inputMapGEMM);
+    indexingGEMMMaps.push_back(filterMapGEMM);
+  }
+  indexingGEMMMaps.push_back(resultMap);
+
+  IGEMMGenericConvDetails igemmDetails;
+  igemmDetails.igemmContractionMaps = indexingGEMMMaps;
+  igemmDetails.igemmOperands = isOutputChannelFirst
+                                   ? SmallVector<Value>({filter, input})
+                                   : SmallVector<Value>({input, filter});
+  igemmDetails.igemmOperands.push_back(output);
+
+  SmallVector<int64_t> igemmLoopBounds;
+  igemmLoopBounds.insert(igemmLoopBounds.end(), outputShape.begin(),
+                         outputShape.begin() + numParallelDims);
+
+  SmallVector<utils::IteratorType> igemmLoopIterators(outputShape.size(),
+                                                      parallel);
+
+  for (auto iter : llvm::enumerate(filterIterators)) {
+    if (iter.value() == reduction) {
+      igemmLoopBounds.push_back(reshapedFilterShape[iter.index()]);
+      igemmLoopIterators.push_back(reduction);
+    }
+  }
+  igemmDetails.igemmLoopBounds = igemmLoopBounds;
+  igemmDetails.filterReassocIndices = filterReassocIndices;
+  igemmDetails.isOutputChannelFirst = isOutputChannelFirst;
+  igemmDetails.convDims = convDims;
+  igemmDetails.igemmLoopIterators = igemmLoopIterators;
+
+  return igemmDetails;
+}
+
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index c0609eb971e2..6357e2c33ff0 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
 
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -160,6 +161,37 @@ FailureOr<SmallVector<int64_t>> getIGEMMLoopBounds(linalg::LinalgOp linalgOp);
 /// layout, the order can be different (e.g., NCHW has the lhs and rhs swapped).
 FailureOr<SmallVector<Value>> getIGEMMOperands(linalg::LinalgOp linalgOp);
 
+/// Struct that holds inferred IGEMM details for a convolution operation.
+struct IGEMMGenericConvDetails {
+  /// The indexing maps array for a convolution operation with IGEMM
+  /// indexing. The resulting indexing maps represent the indexing of some
+  /// contraction that computes the equivalent IGEMM matmul of the
+  /// convolution.
+  SmallVector<AffineMap> igemmContractionMaps;
+  /// The loop bounds of a convolution op with IGEMM indexing. These bounds
+  /// assume the same ordering of dimensions as igemmContractionMaps.
+  SmallVector<int64_t> igemmLoopBounds;
+  /// The operand list for a convolution with IGEMM indexing. This is used
+  /// to determine which inputs are the lhs and rhs, since depending on the
+  /// layout, the order can be different (e.g., NCHW has the lhs and rhs
+  /// swapped).
+  SmallVector<Value> igemmOperands;
+  /// The inferred convolution dimensions.
+  mlir::linalg::ConvolutionDimensions convDims;
+  /// The reassociation indices used to compute the collapsed shape of the
+  /// filter in the IGEMM transformation.
+  SmallVector<ReassociationIndices> filterReassocIndices;
+  /// The iterator type list for a convolution with IGEMM indexing.
+  SmallVector<utils::IteratorType> igemmLoopIterators;
+  /// Indicates whether the output channel comes before the output image in
+  /// the output. This determines the lhs/rhs ordering.
+  bool isOutputChannelFirst;
+};
+
+/// Populates `IGEMMGenericConvDetails` for a given convolution operation.
+FailureOr<IGEMMGenericConvDetails>
+getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp);
+
 /// Returns true if the operation increases bitwidths of tensors.
 /// This function checks that the genericOp:
 /// 1. Has only one output.
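
---

A minimal, standalone sketch (not part of the patch) of the row-major basis
computation that `getBasisFromShape` performs. The free function
`basisFromShape` and the `main` harness below are hypothetical stand-ins
written for illustration; the expected values line up with the
`m_offset = [0, 0] * [14, 1]` and `k_offset = [0, 0] * [4, 1]` bases in the
`conv_nhwc_hwfc` CHECK lines above.

#include <cassert>
#include <cstdint>
#include <vector>

// basis[i] is the product of all dimensions after i (row-major strides),
// e.g. shape {14, 14} -> {14, 1} and shape {9, 4} -> {4, 1}.
static std::vector<int64_t> basisFromShape(const std::vector<int64_t> &shape) {
  std::vector<int64_t> basis(shape.size());
  int64_t cumulativeProduct = 1;
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
    basis[i] = cumulativeProduct;
    cumulativeProduct *= shape[i];
  }
  return basis;
}

int main() {
  // The M dims of @conv_nhwc_hwfc are the output H x W = {14, 14}.
  assert((basisFromShape({14, 14}) == std::vector<int64_t>{14, 1}));
  // The K dims are the collapsed filter window (3*3 = 9) and channels (4).
  assert((basisFromShape({9, 4}) == std::vector<int64_t>{4, 1}));
  return 0;
}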
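If a regression does show up, the escape hatch described in the commit
message is to flip the flag off; assuming the usual spelling for IREE
compiler options, that would look like
`--iree-gpu-use-tile-and-fuse-generic-convolution=false` on the compiler
invocation, which restores the named-op-only IGEMM paths.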