diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 81bb71c3052c..d90e18b90d36 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -31,6 +31,14 @@ namespace mlir::iree_compiler::IREE::GPU { +// TODO (nirvedhmeshram) : This flag allows a lot more convolutions to use IGEMM +// so drop this flag after sufficient use with no issues. +llvm::cl::opt clGPUUseTileAndFuseGenericConvolution( + "iree-gpu-use-tile-and-fuse-generic-convolution", + llvm::cl::desc( + "enable the tile and fuse pipeline for generic convolutions"), + llvm::cl::init(true)); + constexpr int64_t kCacheLineSizeBits = 128 * 8; constexpr int64_t kPreferredCopyNumBits = 128; @@ -371,12 +379,25 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target, return failure(); LDBG("IGEMM TileAndFuse Config"); - FailureOr> igemmContractionMaps = - LinalgExt::getIGEMMContractionIndexingMaps(linalgOp); - FailureOr> igemmLoopBounds = - LinalgExt::getIGEMMLoopBounds(linalgOp); - FailureOr> igemmOperands = - LinalgExt::getIGEMMOperands(linalgOp); + FailureOr> igemmContractionMaps; + FailureOr> igemmLoopBounds; + FailureOr> igemmOperands; + if (!clGPUUseTileAndFuseGenericConvolution) { + igemmContractionMaps = LinalgExt::getIGEMMContractionIndexingMaps(linalgOp); + igemmLoopBounds = LinalgExt::getIGEMMLoopBounds(linalgOp); + igemmOperands = LinalgExt::getIGEMMOperands(linalgOp); + } else { + FailureOr igemmGenericConvDetails = + LinalgExt::getIGEMMGenericConvDetails(linalgOp); + if (failed(igemmGenericConvDetails)) { + LDBG("Unsupported generic convolution type"); + return failure(); + } + igemmContractionMaps = igemmGenericConvDetails->igemmContractionMaps; + igemmLoopBounds = igemmGenericConvDetails->igemmLoopBounds; + igemmOperands = igemmGenericConvDetails->igemmOperands; + } + if (failed(igemmContractionMaps) || failed(igemmLoopBounds) || failed(igemmOperands)) { LDBG("Unsupported convolution type"); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp index 905252b31b3f..e37d3f9488f4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp @@ -17,9 +17,8 @@ namespace mlir::iree_compiler::IREE::LinalgExt { #define GEN_PASS_DEF_CONVERTCONV2DTOIM2COLOPPASS #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc" -static bool hasAllOneValues(DenseIntElementsAttr attr) { - return llvm::all_of( - attr, [](APInt element) { return element.getSExtValue() == 1; }); +static bool hasAllOneValues(ArrayRef attr) { + return llvm::all_of(attr, [](int64_t element) { return element == 1; }); } static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { @@ -36,12 +35,37 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) { return builder.create(loc, x, y); } +// TODO : Upstream utility that does this pruning is broken for LinalgOp. Drop +// this if that gets fixed. +static SmallVector getPrunedAttributeList(linalg::LinalgOp op) { + const StringLiteral memoAttr = + linalg::LinalgDialect::kMemoizedIndexingMapsAttrName; + SmallVector prunedAttributeList; + for (auto attr : op->getDiscardableAttrs()) { + if (attr.getName() != memoAttr) { + prunedAttributeList.push_back(attr); + } + } + return prunedAttributeList; +} + +// Helper to convert a shape into basis for im2col op. +static SmallVector getBasisFromShape(ArrayRef shape) { + SmallVector basis(shape.size()); + int64_t cummulativeProduct = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + basis[i] = cummulativeProduct; + cummulativeProduct *= shape[i]; + } + return basis; +} + namespace { using ControlFnTy = std::function; - -// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) +// Converts non-depthwise convs into into linalg.generic (for img2col packing) // and linalg.matmul. +// The following explains this for a linalg.conv_2d_nhwc_hwcf op. // // A convolution operaton can be written as a matrix-matrix multiplication by // unfolding the cross correlation between input and filter and explicitly copy @@ -73,219 +97,141 @@ using ControlFnTy = std::function; // multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in // the N input. For the case where N > 1 its a batched matrxi-matrix // multplication. -class ConvertConv2DNhwcHwcf final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - ConvertConv2DNhwcHwcf(MLIRContext *context, - std::optional controlFn) - : OpRewritePattern(context), - controlFn(controlFn) {} +class ConvertConvGeneric final + : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, + ConvertConvGeneric(MLIRContext *context, std::optional controlFn) + : OpInterfaceRewritePattern(context), controlFn(controlFn) {} + LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const override { - if (controlFn.has_value() && !controlFn.value()(convOp)) { - return rewriter.notifyMatchFailure(convOp, "controlFn failed."); + if (controlFn.has_value() && !controlFn.value()(linalgOp)) { + return rewriter.notifyMatchFailure(linalgOp, "controlFn failed."); } - auto inputType = llvm::cast(convOp.getInputs()[0].getType()); - auto filterType = llvm::cast(convOp.getInputs()[1].getType()); - auto outputType = llvm::cast(convOp.getOutputs()[0].getType()); - - if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) { - return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { - diag << "[unimplemented] " - << "expected 'filterType' and 'inputType' to have static shapes."; - }); + auto igemmConvDetailsOrFailure = + LinalgExt::getIGEMMGenericConvDetails(linalgOp); + if (failed(igemmConvDetailsOrFailure)) + return rewriter.notifyMatchFailure(linalgOp, + "Failed to extract IGEMM details"); + + LinalgExt::IGEMMGenericConvDetails igemmConvDetails = + *igemmConvDetailsOrFailure; + + SmallVector igemmContractionMaps = + igemmConvDetails.igemmContractionMaps; + mlir::linalg::ConvolutionDimensions convDims = igemmConvDetails.convDims; + SmallVector filterReassocIndices = + igemmConvDetails.filterReassocIndices; + bool isOutputChannelFirst = igemmConvDetails.isOutputChannelFirst; + SmallVector igemmLoopBounds = igemmConvDetails.igemmLoopBounds; + SmallVector igemmLoopIterators = + igemmConvDetails.igemmLoopIterators; + + Value input = linalgOp.getDpsInputs()[0]; + Value filter = linalgOp.getDpsInputs()[1]; + Value output = linalgOp.getDpsInits()[0]; + auto inputType = llvm::cast(input.getType()); + auto filterType = llvm::cast(filter.getType()); + auto outputType = llvm::cast(output.getType()); + + ArrayRef filterShape = filterType.getShape(); + ArrayRef outputShape = outputType.getShape(); + ArrayRef inputShape = inputType.getShape(); + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + AffineMap inputMap = indexingMaps[0]; + AffineMap filterMap = indexingMaps[1]; + AffineMap outputMap = indexingMaps[2]; + + SmallVector kernelSizes; + for (auto filterLoop : convDims.filterLoop) { + std::optional maybeDim = filterMap.getResultPosition( + getAffineDimExpr(filterLoop, filterMap.getContext())); + if (!maybeDim) { + return rewriter.notifyMatchFailure(linalgOp, + "Failed to infer filter shape."); + } + kernelSizes.push_back( + rewriter.getIndexAttr(filterShape[maybeDim.value()])); } - // TODO: Support dilation. - if (!hasAllOneValues(convOp.getDilations())) { - return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { - diag << "[unimplemented] " - << "expected no dilations (expected dilations to all be one)."; - }); + // Shape of the resulting tensor from im2col. + SmallVector colTensorShape; + SmallVector batchPos; + for (auto batch : convDims.batch) { + std::optional maybeBatch = inputMap.getResultPosition( + getAffineDimExpr(batch, inputMap.getContext())); + if (!maybeBatch) { + return rewriter.notifyMatchFailure(linalgOp, + "Failed to infer batch shape."); + } + batchPos.push_back(maybeBatch.value()); + colTensorShape.push_back(inputShape[maybeBatch.value()]); } - Value input = convOp.getInputs()[0]; - Value filter = convOp.getInputs()[1]; - Value output = convOp.getOutputs()[0]; - - auto filterShape = filterType.getShape(); - auto outputShape = outputType.getShape(); - - const int n = outputShape[0]; - const int oh = outputShape[1]; - const int ow = outputShape[2]; - const int oc = outputShape[3]; - const int fh = filterShape[0]; - const int fw = filterShape[1]; - const int ic = filterShape[2]; - - auto loc = convOp.getLoc(); - - SmallVector colTensorShape = {n, oh, ow, fh * fw * ic}; - - SmallVector outputReassocIndices = {{0}, {1, 2}, {3}}; - - Value colTensor = rewriter.create( - loc, colTensorShape, inputType.getElementType()); - SmallVector strides(convOp.getStrides().getValues()); - SmallVector dilations(convOp.getDilations().getValues()); - SmallVector kernelSize = {rewriter.getIndexAttr(fh), - rewriter.getIndexAttr(fw)}; - OpFoldResult zero = rewriter.getIndexAttr(0); - OpFoldResult one = rewriter.getIndexAttr(1); - SmallVector mOffset = {zero, zero}; - SmallVector mBasis = {rewriter.getIndexAttr(ow), one}; - SmallVector kOffset = {zero}; - SmallVector kBasis = {one}; - SmallVector batchPos = {0}; - SmallVector mPos = {1, 2}; - SmallVector kPos = {3}; - Value img2ColTensor = rewriter - .create( - loc, input, /*output=*/colTensor, strides, - dilations, kernelSize, mOffset, mBasis, - kOffset, kBasis, batchPos, mPos, kPos) - .getResult(0); - - SmallVector filterReassocIndices = {{0, 1, 2}, {3}}; - auto reshapedFilterType = - RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType()); - - Value reshapedFilter = rewriter.create( - loc, reshapedFilterType, filter, filterReassocIndices); - - SmallVector indexingMaps = - getIGEMMContractionIndexingMaps(convOp).value(); - auto parallel = utils::IteratorType::parallel; - auto reduction = utils::IteratorType::reduction; - SmallVector genericIterators = { - parallel, parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create( - loc, outputType, - /*inputs=*/ValueRange{img2ColTensor, reshapedFilter}, - /*outputs=*/ValueRange{output}, indexingMaps, genericIterators, - [](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value lhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[0], - args[2].getType(), - /*isUnsignedCast=*/false); - Value rhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[1], - args[2].getType(), - /*isUnsignedCast=*/false); - Value mul = createMul(nestedLoc, lhs, rhs, nestedBuilder); - Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder); - nestedBuilder.create(nestedLoc, add); - }, - linalg::getPrunedAttributeList(convOp)); - Value result = genericOp.getResults().front(); - - rewriter.replaceOp(convOp, result); - - return success(); - } - -private: - std::optional controlFn; -}; - -// For nchw, because the channels are to the left of the image shape dimensions, -// the position of the contraction dimension in the resulting matmul is -// reversed. This swaps the LHS and RHS of the matmul when compared with nhwc -// (i.e. (D, C x Kh x Kw) * (C x Kh x Kw, Ho x Wo)) -class ConvertConv2DNchwFchw final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - ConvertConv2DNchwFchw(MLIRContext *context, - std::optional controlFn) - : OpRewritePattern(context), - controlFn(controlFn) {} - - LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, - PatternRewriter &rewriter) const override { - if (controlFn.has_value() && !controlFn.value()(convOp)) { - return rewriter.notifyMatchFailure(convOp, "controlFn failed."); + SmallVector mPos; + SmallVector mShape; + for (auto outputImage : convDims.outputImage) { + for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) { + if (e.isFunctionOfDim(outputImage)) { + mPos.push_back(idx); + } + } + for (auto [idx, e] : llvm::enumerate(outputMap.getResults())) { + if (e.isFunctionOfDim(outputImage)) { + mShape.push_back(outputShape[idx]); + colTensorShape.push_back(outputShape[idx]); + } + } } - auto inputType = llvm::cast(convOp.getInputs()[0].getType()); - auto filterType = llvm::cast(convOp.getInputs()[1].getType()); - auto outputType = llvm::cast(convOp.getOutputs()[0].getType()); - - if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) { - return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { - diag << "[unimplemented] " - << "expected 'filterType' and 'inputType' to have static shapes."; - }); + SmallVector kPos; + for (auto reductionDim : convDims.inputChannel) { + for (auto [idx, e] : llvm::enumerate(inputMap.getResults())) { + if (e.isFunctionOfDim(reductionDim)) { + kPos.push_back(idx); + } + } } - - // TODO: Support dilation. - if (!hasAllOneValues(convOp.getDilations())) - return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { - diag << "[unimplemented] " - << "expected no dilations (expected dilations to all be one)."; - }); - - Value input = convOp.getInputs()[0]; - Value filter = convOp.getInputs()[1]; - Value output = convOp.getOutputs()[0]; - - auto filterShape = filterType.getShape(); - auto outputShape = outputType.getShape(); - - const int n = outputShape[0]; - const int oc = outputShape[1]; - const int oh = outputShape[2]; - const int ow = outputShape[3]; - const int ic = filterShape[1]; - const int fh = filterShape[2]; - const int fw = filterShape[3]; - - auto loc = convOp.getLoc(); - - SmallVector colTensorShape = {n, oh, ow, fh * fw * ic}; - + // The index at which the reduction dimension bounds starts in + // igemmLoopBounds. + int64_t reductionBoundIndex = convDims.batch.size() + + convDims.outputImage.size() + + convDims.outputChannel.size(); + SmallVector kShape(igemmLoopBounds.begin() + reductionBoundIndex, + igemmLoopBounds.end()); + colTensorShape.insert(colTensorShape.end(), kShape.begin(), kShape.end()); + + SmallVector mBasis = + getAsIndexOpFoldResult(getContext(), getBasisFromShape(mShape)); + SmallVector kBasis = + getAsIndexOpFoldResult(getContext(), getBasisFromShape(kShape)); + + SmallVector kOffset(kBasis.size(), rewriter.getIndexAttr(0)); + SmallVector mOffset(mBasis.size(), rewriter.getIndexAttr(0)); + auto loc = linalgOp.getLoc(); Value colTensor = rewriter.create( loc, colTensorShape, inputType.getElementType()); - SmallVector strides(convOp.getStrides().getValues()); - SmallVector dilations(convOp.getDilations().getValues()); - SmallVector kernelSize = {rewriter.getIndexAttr(fh), - rewriter.getIndexAttr(fw)}; - OpFoldResult zero = rewriter.getIndexAttr(0); - OpFoldResult one = rewriter.getIndexAttr(1); - SmallVector mOffset = {zero, zero}; - SmallVector mBasis = {rewriter.getIndexAttr(ow), one}; - SmallVector kOffset = {zero}; - SmallVector kBasis = {one}; - SmallVector batchPos = {0}; - SmallVector mPos = {2, 3}; - SmallVector kPos = {1}; - Value img2ColTensor = rewriter - .create( - loc, input, /*output=*/colTensor, strides, - dilations, kernelSize, mOffset, mBasis, - kOffset, kBasis, batchPos, mPos, kPos) - .getResult(0); + Value img2ColTensor = + rewriter + .create( + loc, input, /*output=*/colTensor, convDims.strides, + convDims.dilations, kernelSizes, mOffset, mBasis, kOffset, + kBasis, batchPos, mPos, kPos) + .getResult(0); - SmallVector filterReassocIndices = {{0}, {1, 2, 3}}; - auto reshapedFilterType = - RankedTensorType::get({oc, fh * fw * ic}, inputType.getElementType()); Value reshapedFilter = rewriter.create( - loc, reshapedFilterType, filter, filterReassocIndices); + loc, filter, filterReassocIndices); - SmallVector indexingMaps = - getIGEMMContractionIndexingMaps(convOp).value(); - auto parallel = utils::IteratorType::parallel; - auto reduction = utils::IteratorType::reduction; - SmallVector genericIterators = { - parallel, parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create( + auto genericGEMMOp = rewriter.create( loc, outputType, - /*inputs=*/ValueRange{reshapedFilter, img2ColTensor}, - /*outputs=*/ValueRange{output}, indexingMaps, genericIterators, + /*inputs=*/ + isOutputChannelFirst ? ValueRange{reshapedFilter, img2ColTensor} + : ValueRange{img2ColTensor, reshapedFilter}, + /*outputs=*/ValueRange{output}, igemmContractionMaps, + igemmLoopIterators, [](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { Value lhs = convertScalarToDtype(nestedBuilder, nestedLoc, args[0], args[2].getType(), @@ -296,12 +242,11 @@ class ConvertConv2DNchwFchw final Value mul = createMul(nestedLoc, lhs, rhs, nestedBuilder); Value add = createAdd(nestedLoc, mul, args[2], nestedBuilder); nestedBuilder.create(nestedLoc, add); - }, - linalg::getPrunedAttributeList(convOp)); - Value result = genericOp.getResults().front(); - - rewriter.replaceOp(convOp, result); + }); + genericGEMMOp->setDiscardableAttrs(getPrunedAttributeList(linalgOp)); + Value result = genericGEMMOp.getResults().front(); + rewriter.replaceOp(linalgOp, result); return success(); } @@ -327,8 +272,8 @@ struct ConvertConv2DToIm2ColOpPass final void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns, std::optional controlFn) { - patterns.insert( - patterns.getContext(), std::move(controlFn)); + patterns.insert(patterns.getContext(), + std::move(controlFn)); } } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir index 7a83d465fd16..f993f61649a4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv2d_to_im2col.mlir @@ -137,3 +137,61 @@ util.func public @conv_strided(%arg0: tensor<1x16x16x4xf16>, %arg1: tensor<3x3x4 // CHECK: arith.addf // CHECK: } -> tensor<1x7x7x16xf32> // CHECK: util.return %[[MATMUL]] : tensor<1x7x7x16xf32> + +// ----- +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d3, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +util.func public @conv_nhwc_hwfc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x16x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x16x4xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<1x14x14x16xf32> + util.return %0 : tensor<1x14x14x16xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d5)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d3, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK: util.func public @conv_nhwc_hwfc( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<3x3x16x4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x14x14x16xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x9x4xf32> +// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col +// CHECK-SAME: m_offset = [0, 0] * [14, 1] k_offset = [0, 0] * [4, 1] +// CHECK-SAME: ins(%[[ARG0]] : tensor<1x16x16x4xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x14x14x9x4xf32>) -> tensor<1x14x14x9x4xf32> +// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]] : tensor<3x3x16x4xf32> into tensor<9x16x4xf32> +// CHECK: %[[MATMUL:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x9x4xf32>, tensor<9x16x4xf32>) +// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32> + +// ----- +util.func public @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + util.return %0 : tensor<1x14x14x16xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> +// CHECK: util.func public @conv_2d_nhwc_fhwc( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x16x16x4xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<16x3x3x4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x14x14x16xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x14x14x36xf32> +// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col +// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32> +// CHECK: %[[MATMUL:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : tensor<1x14x14x36xf32>, tensor<16x36xf32>) +// CHECK: util.return %[[MATMUL]] : tensor<1x14x14x16xf32> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel index 9196a939e7ff..bf240b8d83f3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/BUILD.bazel @@ -31,6 +31,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgUtils", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt index d9d69318c576..24eb1c5852c6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/CMakeLists.txt @@ -27,6 +27,7 @@ iree_cc_library( MLIRArithDialect MLIRIR MLIRLinalgDialect + MLIRLinalgUtils MLIRMemRefDialect MLIRSupport MLIRTensorDialect diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp index 1922bfb2a2f5..734d7c9c4555 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp @@ -8,16 +8,26 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" +#define DEBUG_TYPE "iree-linalgExt-utils" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir::iree_compiler::IREE::LinalgExt { +static bool hasAllOneValues(ArrayRef attr) { + return llvm::all_of(attr, [](int64_t element) { return element == 1; }); +} + OpFoldResult addOfrs(OpBuilder &builder, Location loc, OpFoldResult a, OpFoldResult b) { AffineExpr d0, d1; @@ -452,4 +462,216 @@ FailureOr> getIGEMMOperands(linalg::LinalgOp linalgOp) { .Default([](Operation *) { return failure(); }); } +FailureOr +getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { + + auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); + MLIRContext *ctx = linalgOp->getContext(); + if (failed(convDimsOrFailure)) + return failure(); + const mlir::linalg::ConvolutionDimensions &convDims = *convDimsOrFailure; + LLVM_DEBUG({ + llvm::dbgs() << "conv: " << linalgOp; + llvm::dbgs() << "\nconv batch dim: "; + llvm::interleaveComma(convDims.batch, llvm::dbgs()); + llvm::dbgs() << "\nconv output window dims: "; + llvm::interleaveComma(convDims.outputImage, llvm::dbgs()); + llvm::dbgs() << "\nconv output channel dim: "; + llvm::interleaveComma(convDims.outputChannel, llvm::dbgs()); + llvm::dbgs() << "\nconv filter window dims: "; + llvm::interleaveComma(convDims.filterLoop, llvm::dbgs()); + llvm::dbgs() << "\nconv input channel dims: "; + llvm::interleaveComma(convDims.inputChannel, llvm::dbgs()); + llvm::dbgs() << "\nconv depth multiplier: "; + llvm::interleaveComma(convDims.depth, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + Value input = linalgOp.getDpsInputs()[0]; + Value filter = linalgOp.getDpsInputs()[1]; + Value output = linalgOp.getDpsInits()[0]; + auto inputType = llvm::cast(input.getType()); + auto filterType = llvm::cast(filter.getType()); + auto outputType = llvm::cast(output.getType()); + + if (!filterType.hasStaticShape() || !inputType.hasStaticShape()) { + LDBG("[unimplemented] expected 'filterType' and 'inputType' to have static " + "shapes."); + return failure(); + } + + // TODO: Support dilation. + if (!hasAllOneValues(convDims.dilations)) { + LDBG("[unimplemented] expected no dilations (expected dilations to all be " + "one)."); + return failure(); + } + // TODO: Support depthwise. + if (!convDims.depth.empty()) { + LDBG("[unimplemented] expected no depth"); + return failure(); + } + + // TODO: Support pooling operations. For pooling ops, the input/output channel + // size will be categorized as the additional batch dimension. + if (convDims.outputChannel.empty() || convDims.inputChannel.empty()) { + LDBG("[unimplemented] expected no pooling operations"); + return failure(); + } + auto filterShape = filterType.getShape(); + auto outputShape = outputType.getShape(); + auto indexingMaps = linalgOp.getIndexingMapsArray(); + auto filterMap = indexingMaps[1]; + + SmallVector reductionDims; + for (auto iter : llvm::enumerate(linalgOp.getIteratorTypesArray())) { + if (linalg::isReductionIterator(iter.value())) { + reductionDims.push_back(iter.index()); + } + } + SmallVector filterkPos; + for (auto reductionDim : reductionDims) { + std::optional maybeDim = filterMap.getResultPosition( + getAffineDimExpr(reductionDim, filterMap.getContext())); + filterkPos.push_back(maybeDim.value()); + } + // group together adjacent reduction dimensions in the filter + SmallVector collapsedFilterReductionDim; + int64_t prevFilterIndex = filterkPos[0]; + int64_t currCollapsedIndex = 0; + collapsedFilterReductionDim.push_back({filterkPos[0]}); + SmallVector kShape = {filterShape[filterkPos[0]]}; + for (auto currPos : llvm::ArrayRef(filterkPos).drop_front()) { + if (prevFilterIndex == currPos - 1) { + collapsedFilterReductionDim[currCollapsedIndex].push_back(currPos); + } else { + collapsedFilterReductionDim.push_back({currPos}); + ++currCollapsedIndex; + } + prevFilterIndex = currPos; + } + + auto parallel = utils::IteratorType::parallel; + auto reduction = utils::IteratorType::reduction; + SmallVector filterIterators; + SmallVector filterNdims; + for (auto outputChannel : convDims.outputChannel) { + std::optional maybeDim = filterMap.getResultPosition( + getAffineDimExpr(outputChannel, filterMap.getContext())); + filterNdims.push_back(maybeDim.value()); + } + SmallVector filterReassocIndices; + // Interleave the parallel dims with the reduction dims. + int64_t filterNdimPos = 0; + for (auto collapsedDim : collapsedFilterReductionDim) { + for (int i = filterNdimPos; i < filterNdims.size(); i++) { + if (filterNdims[i] < collapsedDim[0]) { + filterReassocIndices.push_back({filterNdims[i]}); + filterIterators.push_back(parallel); + filterNdimPos = i + 1; + } else { + break; + } + } + filterIterators.push_back(reduction); + filterReassocIndices.push_back(collapsedDim); + } + // insert any leftover parallel dims in the end. + for (int i = filterNdimPos; i < filterNdims.size(); i++) { + filterReassocIndices.push_back({filterNdims[i]}); + filterIterators.push_back(parallel); + } + SmallVector reshapedFilterShape(filterReassocIndices.size(), 1); + for (auto [idx, indices] : llvm::enumerate(filterReassocIndices)) { + for (auto index : indices) { + reshapedFilterShape[idx] *= filterShape[index]; + } + } + + int64_t numBDims = (convDims.batch).size(); + int64_t numMDims = (convDims.outputImage).size(); + int64_t numNDims = (convDims.outputChannel).size(); + int64_t numParallelDims = numBDims + numMDims + numNDims; + int64_t numKDims = collapsedFilterReductionDim.size(); + SmallVector genericIterators(numParallelDims, parallel); + genericIterators.insert(genericIterators.end(), numKDims, reduction); + + SmallVector dims(numParallelDims + numKDims); + bindDimsList(ctx, dims); + auto resultMap = AffineMap::get( + numParallelDims + numKDims, 0, + SmallVector(dims.begin(), dims.begin() + numParallelDims), + ctx); + + bool isOutputChannelFirst = false; + auto outputChannelPos = convDims.outputChannel; + auto outputImagePos = convDims.outputImage; + if (outputChannelPos.back() < outputImagePos[0]) + isOutputChannelFirst = true; + + // prepare the input map. + SmallVector inputDims; + // Add the batch dimensions. + inputDims.insert(inputDims.end(), dims.begin(), dims.begin() + numBDims); + int64_t starting_m_pos = + isOutputChannelFirst ? numBDims + numNDims : numBDims; + // Add the M dims. + inputDims.insert(inputDims.end(), dims.begin() + starting_m_pos, + dims.begin() + starting_m_pos + numMDims); + // Add the reduction dims. + inputDims.insert(inputDims.end(), dims.begin() + numParallelDims, dims.end()); + auto inputMapGEMM = + AffineMap::get(numParallelDims + numKDims, 0, inputDims, ctx); + + // prepare filter map. + SmallVector filterDims; + int64_t curr_n_pos = isOutputChannelFirst ? numBDims : numBDims + numMDims; + int64_t curr_k_pos = numBDims + numMDims + numNDims; + + for (auto iter : filterIterators) { + if (iter == parallel) { + filterDims.push_back(dims[curr_n_pos++]); + } else if (iter == reduction) { + filterDims.push_back(dims[curr_k_pos++]); + } + } + auto filterMapGEMM = + AffineMap::get(numParallelDims + numKDims, 0, filterDims, ctx); + + SmallVector indexingGEMMMaps; + if (isOutputChannelFirst) { + indexingGEMMMaps.push_back(filterMapGEMM); + indexingGEMMMaps.push_back(inputMapGEMM); + } else { + indexingGEMMMaps.push_back(inputMapGEMM); + indexingGEMMMaps.push_back(filterMapGEMM); + } + indexingGEMMMaps.push_back(resultMap); + IGEMMGenericConvDetails igemmDetails; + igemmDetails.igemmContractionMaps = indexingGEMMMaps; + igemmDetails.igemmOperands = isOutputChannelFirst + ? SmallVector({filter, input}) + : SmallVector({input, filter}); + igemmDetails.igemmOperands.push_back(output); + SmallVector igemmLoopBounds; + igemmLoopBounds.insert(igemmLoopBounds.end(), outputShape.begin(), + outputShape.begin() + numParallelDims); + + SmallVector igemmLoopIterators(outputShape.size(), + parallel); + + for (auto iter : llvm::enumerate(filterIterators)) { + if (iter.value() == reduction) { + igemmLoopBounds.push_back(reshapedFilterShape[iter.index()]); + igemmLoopIterators.push_back(reduction); + } + } + igemmDetails.igemmLoopBounds = igemmLoopBounds; + igemmDetails.filterReassocIndices = filterReassocIndices; + igemmDetails.isOutputChannelFirst = isOutputChannelFirst; + igemmDetails.convDims = convDims; + igemmDetails.igemmLoopIterators = igemmLoopIterators; + + return igemmDetails; +} + } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h index c0609eb971e2..6357e2c33ff0 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h @@ -8,6 +8,7 @@ #define IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_ #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -160,6 +161,37 @@ FailureOr> getIGEMMLoopBounds(linalg::LinalgOp linalgOp); /// layout, the order can be different (e.g., NCHW has the lhs and rhs swapped). FailureOr> getIGEMMOperands(linalg::LinalgOp linalgOp); +/// Struct that holds inferred IGEMM details for a convolution operation. +struct IGEMMGenericConvDetails { + /// The indexing maps array for a convolution operation with IGEMM + /// indexing. The resulting indexing maps represents the indexing of some + /// contraction that computes the equivalent IGEMM matmul of the convolution. + SmallVector igemmContractionMaps; + /// The loop bounds of a convolution op with IGEMM indexing. This + /// function assumes the same ordering of dimensions as + /// igemmContractionMaps; + SmallVector igemmLoopBounds; + /// The operand list for a convolution with IGEMM indexing. This is + /// used to determine which inputs are the lhs and rhs, since depending on the + /// layout, the order can be different (e.g., NCHW has the lhs and rhs + /// swapped). + SmallVector igemmOperands; + /// The inferred convolution dimensions. + mlir::linalg::ConvolutionDimensions convDims; + /// The reassociation indices used to computer the collapse shape of the + /// filter in IGEMM transformation. + SmallVector filterReassocIndices; + /// The iterator type list for a convolution with IGEMM indexing. . + SmallVector igemmLoopIterators; + /// Indicates if the OutputChannel is before the OutputImage in the output. + /// This determines our lhs/rhs ordering. + bool isOutputChannelFirst; +}; + +/// Populate `IGEMMGenericConvDetails` for a given convolution operation. +FailureOr +getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp); + /// Returns true if the operation increases bitwidths of tensors. /// This function checks that the genericOp: /// 1. Has only one output.