From 7e6c5ec26332c55d1912d3f7fd2cf2350a936768 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Mon, 27 Jan 2025 14:06:07 -0800 Subject: [PATCH] Switch to upstream StablehloToLinalg code. (#19792) While looking at compiler warnings in build logs, I noticed paths in StableHLO that looked out of place. As it turns out, much of IREE's StableHLO to Linalg conversion code was forked into upstream StableHLO in https://github.com/openxla/stablehlo/pull/1817, though there have been some local changes to the code here since it was forked: https://github.com/iree-org/iree/commits/main/compiler/plugins/input/StableHLO/Conversion. Switching to use the upstream code will allow us to decrease the surface area we directly support and limit the number of files we need to build from source, but it will also make maintenance require coordinating more with upstream (such as during LLVM integrates). We still point to a fork at https://github.com/iree-org/stablehlo , so if things get tricky we can choose to set up a branch with patches as needed. Some notes: * More code, particularly includes and build dependencies, could be pruned. * We can probably delete more code by reviving https://github.com/iree-org/iree/pull/18681 too * I deleted lit tests for the patterns that were moved upstream. The tests still exist at https://github.com/openxla/stablehlo/tree/main/stablehlo/conversions/linalg/tests, but I don't see much value in having our own versions of the lit tests. We do still have e2e tests that compile and run. * I did _not_ plumb through the `enablePrimitiveOps` or `enableSparseOps` options, which may be useful for some programs * I'm keeping our custom `stablehlo.concatenate` lowering since the alternate lowering (from IREE or now StableHLO) has correctness issues. Also keeping the FFT lowering since that does not exist upstream and it handles cases that our LinalgExt lowering does not. --- .../bazel_to_cmake/bazel_to_cmake_targets.py | 3 + .../input/StableHLO/Conversion/BUILD.bazel | 10 +- .../input/StableHLO/Conversion/CMakeLists.txt | 10 +- .../Conversion/LegalizeToLinalgUtils.cpp | 129 - .../Conversion/LegalizeToLinalgUtils.h | 76 - .../input/StableHLO/Conversion/Passes.cpp | 1 - .../input/StableHLO/Conversion/Passes.h | 9 +- .../input/StableHLO/Conversion/Passes.td | 9 - .../input/StableHLO/Conversion/Rewriters.h | 53 - .../StableHLO/Conversion/StableHLOToArith.cpp | 146 - .../StableHLOToIREEInputDialects.cpp | 50 +- .../Conversion/StableHLOToLinalg.cpp | 2683 ----------------- .../StableHLOToLinalgConvolution.cpp | 807 ----- .../Conversion/StableHLOToLinalgDotProd.cpp | 291 -- .../Conversion/StableHLOToLinalgPointwise.cpp | 316 -- .../Conversion/StableHLOToLinalgRandom.cpp | 917 ------ .../Conversion/StableHLOToLinalgReduce.cpp | 723 ----- .../StableHLO/Conversion/TypeConversion.cpp | 93 - .../StableHLO/Conversion/TypeConversion.h | 34 - .../StableHLO/Conversion/test/BUILD.bazel | 7 - .../StableHLO/Conversion/test/CMakeLists.txt | 7 - .../Conversion/test/stablehlo_to_linalg.mlir | 1618 ---------- .../test/stablehlo_to_linalg_convolution.mlir | 595 ---- .../test/stablehlo_to_linalg_dot_prod.mlir | 276 -- .../test/stablehlo_to_linalg_gather.mlir | 325 -- .../test/stablehlo_to_linalg_pointwise.mlir | 1489 --------- .../test/stablehlo_to_linalg_random.mlir | 693 ----- .../test/stablehlo_to_linalg_reduce.mlir | 794 ----- 28 files changed, 31 insertions(+), 12133 deletions(-) delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp delete mode 100644 compiler/plugins/input/StableHLO/Conversion/TypeConversion.h delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir delete mode 100644 compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index 0c0469eb335c..0d284e9c6205 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -90,6 +90,9 @@ def __init__(self, repo_map: Dict[str, str]): "@stablehlo//:stablehlo_passes": [ "StablehloPasses", ], + "@stablehlo//:linalg_passes": [ + "StablehloLinalgTransforms", + ], "@stablehlo//:vhlo_ops": [ "VhloOps", ], diff --git a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel index 110c27ae4eac..768cc48b2d83 100644 --- a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel @@ -70,17 +70,8 @@ iree_compiler_cc_library( "LegalizeToLinalgUtils.h", "MapStableHLOToScalarOp.h", "StableHLOCustomCalls.cpp", - "StableHLOToArith.cpp", "StableHLOToIREEInputDialects.cpp", - "StableHLOToLinalg.cpp", - "StableHLOToLinalgConvolution.cpp", - "StableHLOToLinalgDotProd.cpp", "StableHLOToLinalgExt.cpp", - "StableHLOToLinalgPointwise.cpp", - "StableHLOToLinalgRandom.cpp", - "StableHLOToLinalgReduce.cpp", - "TypeConversion.cpp", - "TypeConversion.h", "VerifyCompilerInputLegality.cpp", ], deps = [ @@ -121,6 +112,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:VectorDialect", "@stablehlo//:broadcast_utils", "@stablehlo//:chlo_ops", + "@stablehlo//:linalg_passes", "@stablehlo//:stablehlo_ops", "@stablehlo//:vhlo_ops", ], diff --git a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt index 5afe3a28fdf4..e1e76357b59b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt +++ b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt @@ -56,17 +56,8 @@ iree_cc_library( "LegalizeToLinalgUtils.h" "MapStableHLOToScalarOp.h" "StableHLOCustomCalls.cpp" - "StableHLOToArith.cpp" "StableHLOToIREEInputDialects.cpp" - "StableHLOToLinalg.cpp" - "StableHLOToLinalgConvolution.cpp" - "StableHLOToLinalgDotProd.cpp" "StableHLOToLinalgExt.cpp" - "StableHLOToLinalgPointwise.cpp" - "StableHLOToLinalgRandom.cpp" - "StableHLOToLinalgReduce.cpp" - "TypeConversion.cpp" - "TypeConversion.h" "VerifyCompilerInputLegality.cpp" DEPS ::CHLODecompositionPatterns @@ -99,6 +90,7 @@ iree_cc_library( MLIRTransforms MLIRVectorDialect StablehloBroadcastUtils + StablehloLinalgTransforms StablehloOps VhloOps iree::compiler::Dialect::Flow::IR diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp index 54bf385bb7a5..4520409dee30 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp @@ -13,21 +13,10 @@ #include #include -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" -#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" namespace mlir::iree_compiler::stablehlo { -namespace { -bool hasIntegralShapeType(Operation *op) { - auto stp = llvm::dyn_cast(op->getOperand(0).getType()); - return stp && stp.getElementType().isIntOrIndex(); -} - -} // namespace SmallVector getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) { @@ -42,124 +31,6 @@ getNParallelLoopsAttrs(unsigned nParallelLoops) { return getParallelAndReductionIterators(nParallelLoops, 0); } -Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes) { - return b.create( - loc, llvm::cast(type), dynSizes, - /*copy=*/Value(), - /*memory_space=*/IntegerAttr()); -} - -Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes) { - return b.create( - loc, type.getShape(), type.getElementType(), dynSizes, - llvm::cast(type).getEncoding()); -} - -Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType, - Operation *op, ValueRange operands) { - bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr; - // Collect the sizes for a ranked tensor to be passed as parameter to a - // new tensor initialization operation. This operation only needs the - // dynamic sizes. - SmallVector sizes; - if (resultType.hasRank() && !resultType.hasStaticShape()) { - // Ask the op for its output shape. - auto shapeSource = cast(op); - SmallVector reifiedShapes; - (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes); - assert(reifiedShapes.size() == 1 && "Expected one reified result"); - // Construct sizes for the required dimensions. - for (const auto &en : llvm::enumerate(resultType.getShape())) { - if (!ShapedType::isDynamic(en.value())) - continue; - sizes.push_back(b.create( - loc, reifiedShapes[0], - ValueRange{b.create(loc, en.index())})); - } - } - return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes) - : getEmptyTensor(b, loc, resultType, sizes); -} - -Value coerceTensorShape(OpBuilder &builder, Location loc, - TypedValue value, ShapedType targetType) { - return builder.createOrFold( - loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()), - value); -} - -LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op) { - auto isRankedTensor = [](Value val) { - return isa(val.getType()); - }; - if (!llvm::all_of(op->getOperands(), isRankedTensor)) - return failure(); - return success(llvm::all_of(op->getResults(), isRankedTensor)); -} - -Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor) { - auto type = cast(tensor.getType()); - Value zero; - // Complex numbers are a special case. - if (auto complexType = llvm::dyn_cast(type.getElementType())) { - auto zeroElement = builder.getZeroAttr(complexType.getElementType()); - auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement}); - zero = builder.create(loc, complexType, zeroAttr); - } else { - auto zeroAttr = builder.getZeroAttr(type.getElementType()); - zero = builder.create(loc, zeroAttr); - } - return builder.create(loc, zero, tensor).result(); -} - -Value preSparsify(Operation *op, llvm::SmallVector &values, Type rtp, - OpBuilder *b) { - // Apply for semi-ring operations that lower to elaborate code - // (any sign-op, or an integral abs-op). - // TODO(peiming, ajcbik): these all can potentially be optimized by applying - // value transform on sparse_tenosr.value memref - if (isa(op) || - (isa(op) && hasIntegralShapeType(op)) || - isa(op)) { - if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) && - !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType())) - return Value(); - Location loc = op->getLoc(); - auto semiring = b->create(loc, rtp, values[0]); - Type itp = values[0].getType(); - Block *present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc); - b->setInsertionPointToStart(&semiring.getPresentRegion().front()); - values[0] = present->getArgument(0); - return semiring; - } - return Value(); -} - -Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b) { - if (semiring) { - b->create(op->getLoc(), result); - b->setInsertionPointAfter(semiring.getDefiningOp()); - return semiring; - } - return result; -} - -bool allOperandsAreScalarTensors(Operation *op) { - return llvm::all_of(op->getOperands(), [](Value operand) { - auto operandTy = llvm::dyn_cast(operand.getType()); - return operandTy && operandTy.getRank() == 0; - }); -} - -bool isInBodyOfLinalgOps(Operation *op) { - auto *parentOp = op->getParentRegion()->getParentOp(); - return parentOp->getDialect() == - parentOp->getContext()->getLoadedDialect(); -} - SmallVector extract1DVector(DenseIntElementsAttr elements) { SmallVector ret; for (const APInt &element : elements) { diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h index 2b74488c9ef3..5bbe07fdb4fe 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h @@ -15,26 +15,16 @@ #include #include -#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" namespace mlir::iree_compiler::stablehlo { @@ -49,75 +39,9 @@ getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction); SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops); -/// Generates an init sparse tensor. -Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes); - -/// Generates a tensor.empty op. -Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes); - -/// Generates an empty tensor for the result of the operation, which could be a -/// dense tensor or a sparse tensor. -Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType, - Operation *op, ValueRange operands); - -/// Ensures a tensor has the same shape (not including the element type) as -/// another. -Value coerceTensorShape(OpBuilder &builder, Location loc, - TypedValue value, ShapedType targetType); - -/// Verifies |op|'s semantics by checking if all operands and results have -/// ranged tensor types. -LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op); - -/// Fills |tensor| with a zero constant of the matching type. Returns the new -/// value. -Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor); - -/// Sparsifies a (block of) operation(s) that cannot be handled directly -/// by the sparse compiler but has well-known semi-ring semantics. -/// -/// This yields something of the following form: -/// -/// %result = sparse_tensor.unary %values[0] -/// present={ -/// ^bb1(%val): -/// ... codegen proceeds here using %val .... -/// sparse_tensor.yield -/// } -/// absent={} -/// linalg.yield %result -Value preSparsify(Operation *op, llvm::SmallVector &values, Type rtp, - OpBuilder *b); - -/// Finalizes sparse semi-ring construction. -Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b); - -/// Returns true if all operands are tensors with rank 0. -bool allOperandsAreScalarTensors(Operation *op); - -/// Returns true if parent op is linalg. -bool isInBodyOfLinalgOps(Operation *op); - /// Extracts integer values from the attribute |elements|. SmallVector extract1DVector(DenseIntElementsAttr elements); -/// Returns true if the given |values| is a splat of the given |queryValue|. -inline bool isSplatValue(const ArrayRef &values, int64_t queryValue) { - for (auto value : values) { - if (value != queryValue) { - return false; - } - } - return true; -} - -/// Returns true if the given |attr| is a splat of the given |value|. -inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) { - return attr.isSplat() && attr.getSplatValue() == value; -} - } // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_ diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp index d8a0e66f3cdd..b66efc1f2083 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp @@ -87,7 +87,6 @@ void buildStableHLOInputConversionPassPipelineImpl( stablehlo::createConvertStableHloToLinalgExt()); passManager.addNestedPass(stablehlo::createLegalizeChlo()); passManager.addPass(createConvertStableHloToIreeInputDialects()); - // Ensure conversion completed. passManager.addPass(createReconcileUnrealizedCastsPass()); // Note that some StableHLO ops are left by the above and must resolve via diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.h b/compiler/plugins/input/StableHLO/Conversion/Passes.h index 4b63daeb6905..e6bd1741f4be 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Passes.h +++ b/compiler/plugins/input/StableHLO/Conversion/Passes.h @@ -10,11 +10,7 @@ #include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h" #include "mlir/Pass/Pass.h" -namespace mlir { -class TypeConverter; -namespace iree_compiler::stablehlo { - -std::unique_ptr createStableHloToLinalgTypeConverter(); +namespace mlir::iree_compiler::stablehlo { struct StableHloOptions : public PassPipelineOptions {}; @@ -36,7 +32,6 @@ void buildStableHLOXLAInputConversionPassPipeline( void registerStableHLOConversionPasses(); -} // namespace iree_compiler::stablehlo -} // namespace mlir +} // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_PASSES_H_ diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.td b/compiler/plugins/input/StableHLO/Conversion/Passes.td index 8f7aa2b1ec3d..7852112d453b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Passes.td +++ b/compiler/plugins/input/StableHLO/Conversion/Passes.td @@ -32,15 +32,6 @@ def ConvertStableHloToLinalgExt : // General passes //===----------------------------------------------------------------------===// -def ConvertStableHloToLinalg : - Pass<"iree-stablehlo-to-linalg", "ModuleOp"> { - let summary = "Converts from StableHLO ops to Linalg ops on"; - let options = [Option<"enablePrimitiveOps", "enable-primitive-ops", "bool", - /*default=*/"false", - "Lower to primitive Linalg ops (map, reduce and " - "transpose) when possible, instead of linalg.generic">]; -} - def LegalizeControlFlow : InterfacePass<"iree-stablehlo-legalize-control-flow", "mlir::FunctionOpInterface"> { let summary = "Legalizes from StableHLO control flow to SCF control flow"; diff --git a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h index 8a4380284f52..1a328ff3908b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h +++ b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h @@ -15,12 +15,6 @@ namespace mlir::iree_compiler::stablehlo { // General StableHLO/CHLO lowering patterns. //===----------------------------------------------------------------------===// -/// Populates the patterns that convert from StableHLO to Linalg on tensors. -void populateStableHloToLinalgConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet *patterns, - bool enablePrimitiveOps); - /// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and /// Shape ops. void populateLegalizeChloPatterns(MLIRContext *context, @@ -44,59 +38,12 @@ void populateStableHloToLinalgExtConversionPatterns( MLIRContext *context, TypeConverter &typeConverter, RewritePatternSet *patterns); -/// Populates the patterns that convert from StableHLO to Linalg on tensors. -/// Extends the general linalg lowering patterns with IREE-specific ones. -void populateStableHloToLinalgOnTensorsConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns); - /// Populates the patterns that convert from StableHLO collective ops to Flow /// ops. void populateStableHloCollectivesConversionPatterns( MLIRContext *context, TypeConverter &typeConverter, RewritePatternSet *patterns); -//===----------------------------------------------------------------------===// -// Fine-grained patterns used by the implementation. -//===----------------------------------------------------------------------===// -namespace detail { -/// Populates the patterns that convert from elementwise StableHLO ops to Linalg -/// on tensors. -void populatePointwiseStableHloToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, bool enablePrimitiveOps); - -/// Populates the patterns that convert from convolution StableHLO ops to Linalg -/// on tensors. -void populateStableHloConvolutionToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns); - -/// Populates the patterns that convert from dot product StableHLO ops to Linalg -/// on tensors. -void populateStableHloDotProdToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns); - -/// Populates the patterns that convert from random number generation StableHLO -/// ops to Linalg on tensors. -void populateStableHloRandomToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns); - -/// Populates the patterns that convert from reduction StableHLO ops to Linalg -/// on tensors. -void populateStableHloReductionToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, bool enablePrimitiveOps); - -/// Populates the patterns that convert scalar StableHLO ops to Arith ops. -void populateScalarHloToArithConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, - llvm::function_ref filterFn = nullptr); -} // namespace detail - } // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_REWRITERS_H_ diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp deleted file mode 100644 index 7a62ec2aca80..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering scalar StableHLO ops to arith dialect. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { - -template -struct ScalarHloToFuncPatterns final : OpConversionPattern { - ScalarHloToFuncPatterns(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op->getParentOp())) { - return rewriter.notifyMatchFailure(op, - "Return must be inside a function"); - } - mlir::Operation::operand_range operands = op.getOperands(); - rewriter.replaceOpWithNewOp(op, operands); - return success(); - } -}; -template -struct ScalarHloToArithmeticPattern final : OpConversionPattern { - ScalarHloToArithmeticPattern( - TypeConverter &typeConverter, MLIRContext *context, - llvm::function_ref filterFn = nullptr, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - filterFn(filterFn) {} - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (filterFn && !filterFn(op)) - return failure(); - - auto isScalar = [](Value v) { - return cast(v.getType()).getRank() == 0; - }; - - if (!llvm::all_of(adaptor.getOperands(), isScalar)) - return rewriter.notifyMatchFailure(op, "All operands must be scalar."); - - Location loc = op.getLoc(); - - auto resultTy = dyn_cast_or_null( - this->getTypeConverter()->convertType(op->getResultTypes().front())); - if (!resultTy) - return failure(); - - SmallVector operands; - for (Value operand : adaptor.getOperands()) { - operands.push_back( - rewriter.create(loc, operand, ValueRange())); - } - Value scalarResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, resultTy.getElementType(), operands, &rewriter); - if (!scalarResult) - return failure(); - rewriter.replaceOpWithNewOp(op, resultTy, - scalarResult); - return success(); - } - -private: - llvm::function_ref filterFn; -}; - -} // namespace - -namespace detail { -void populateScalarHloToArithConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, - llvm::function_ref filterFn) { - // TODO(#12678): Handle the XLA rng op. - patterns->add< - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern>(typeConverter, - context, filterFn); - patterns->add>( - typeConverter, context); -} -} // namespace detail - -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp index 65963904b0b7..bce2f06187e6 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp @@ -12,7 +12,6 @@ #include "compiler/plugins/input/StableHLO/Conversion/Passes.h" #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h" #include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" @@ -35,6 +34,8 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/conversions/linalg/transforms/Rewriters.h" +#include "stablehlo/conversions/linalg/transforms/TypeConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -46,7 +47,9 @@ namespace mlir::iree_compiler::stablehlo { namespace { /// Converts stablehlo.concatenate operation to extract_slice ops + insert_slice -/// ops. +/// ops. mlir::stablehlo::populateStablehloToLinalgConversionPatterns provides a +/// lowering to linalg using SCF that has numerics issues when run through IREE, +/// so we use this lowering instead with a higher pattern benefit. struct ConcatenateOpConversion final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -491,10 +494,11 @@ struct ConvertStableHloToIreeInputDialects final : impl::ConvertStableHloToIreeInputDialectsBase< ConvertStableHloToIreeInputDialects> { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert< - IREE::Flow::FlowDialect, IREE::Util::UtilDialect, linalg::LinalgDialect, - arith::ArithDialect, tensor::TensorDialect, shape::ShapeDialect, - math::MathDialect, memref::MemRefDialect, complex::ComplexDialect>(); + registry.insert(); } void runOnOperation() override { @@ -502,7 +506,7 @@ struct ConvertStableHloToIreeInputDialects final RewritePatternSet patterns(context); std::unique_ptr typeConverter = - createStableHloToLinalgTypeConverter(); + std::make_unique<::mlir::stablehlo::LinalgTypeConverter>(); typeConverter->addArgumentMaterialization(scalarToTensor); typeConverter->addSourceMaterialization(scalarToTensor); typeConverter->addTargetMaterialization(scalarToTensor); @@ -511,16 +515,23 @@ struct ConvertStableHloToIreeInputDialects final // expensive expansions. populateCanonicalizationPatterns(context, &patterns, /*benefit=*/1024); - populateStableHloToLinalgOnTensorsConversionPatterns( - context, *typeConverter, &patterns); + // Run custom patterns with a high benefit to override stablehlo patterns. + patterns.add(*typeConverter, context, + PatternBenefit{1000}); + + // Run upstream stablehlo patterns with a default benefit. + ::mlir::stablehlo::populateStablehloToLinalgConversionPatterns( + context, *typeConverter, &patterns, /*enablePrimitiveOps=*/false, + /*enableSparseOps=*/false); + + // Lowerings using IREE-specific operators (and not just common dialects + // like linalg, scf, arith, etc.). populateStableHloCollectivesConversionPatterns(context, *typeConverter, &patterns); - // TODO(#12678): Handle remaining complex ops. - // TODO(*): expose patterns that do this much better from - // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp - + // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp // Structural patterns (functions, cfg, terminators). patterns.add(*typeConverter, context); patterns.add(*typeConverter, context); @@ -626,17 +637,4 @@ struct ConvertStableHloToIreeInputDialects final } // namespace -void populateStableHloToLinalgOnTensorsConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - // TODO(#5809): Drop ConcatenateOp lowering in favor of the upstream version - // then remove the PatternBenefit here - patterns->add(typeConverter, context, - PatternBenefit{1000}); - - populateStableHloToLinalgConversionPatterns(context, typeConverter, patterns, - /*enablePrimitiveOps=*/false); -} - } // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp deleted file mode 100644 index c886df805ec6..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp +++ /dev/null @@ -1,2683 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO/CHLO dialects to Linalg dialect. - -#include -#include -#include - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Passes.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { - -#define GEN_PASS_DEF_CONVERTSTABLEHLOTOLINALG -#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc" - -namespace { -Value getResultValue(Operation *op) { return op->getResult(0); } - -ShapedType getHloOpResultType(Operation *op) { - return llvm::cast(getResultValue(op).getType()); -} - -/// Extracts an element from a tensor and optionally converts it to an index -/// type, based on the tensor's pre-type conversion type. -Value extractIndexFromTensor(OpBuilder &builder, Location loc, Value tensor, - ShapedType originalType, - ArrayRef tensorIndex = {}) { - Value extracted = builder.create(loc, tensor, tensorIndex); - if (extracted.getType().isIndex()) - return extracted; - return originalType.getElementType().isUnsignedInteger() - ? builder.createOrFold( - loc, builder.getIndexType(), extracted) - : builder.createOrFold( - loc, builder.getIndexType(), extracted); -} - -//===----------------------------------------------------------------------===// -// stablehlo.Einsum conversion patterns. -//===----------------------------------------------------------------------===// - -// Looks through a set of dimension that has been marked as reduction axes, -// if it is found within the set, then we set it as "reduction", otherwise -// we can label it as "parallel". -SmallVector -getEinsumLoopsAttrs(const llvm::SmallSetVector &inputInd, - const llvm::SmallSetVector &reductionDims) { - SmallVector res; - for (StringRef dim : inputInd) { - if (!reductionDims.contains(dim)) { - res.push_back(utils::IteratorType::parallel); - } else { - res.push_back(utils::IteratorType::reduction); - } - } - return res; -} - -SmallVector -extractDynamicEinsumSizes(OpBuilder &b, Location loc, Value lhs, Value rhs, - const SmallVector &lhsLoopVec, - const SmallVector &rhsLoopVec, - const SmallVector &outputLoopVec) { - SmallVector dynSizes; - for (const std::string &dimInd : outputLoopVec) { - Value dimSize; - const auto *dimIndIt = llvm::find(lhsLoopVec, dimInd); - if (dimIndIt != lhsLoopVec.end()) { - // Query from lhs vars. - auto dimIndPos = dimIndIt - lhsLoopVec.begin(); - auto lhsShape = - llvm::dyn_cast(lhs.getType()).getShape(); - if (!ShapedType::isDynamic(lhsShape[dimIndPos])) - continue; - dimSize = b.create(loc, lhs, dimIndPos); - } else { - // query from rhs vars. - dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd); - auto dimIndPos = dimIndIt - rhsLoopVec.begin(); - auto rhsShape = - llvm::dyn_cast(rhs.getType()).getShape(); - if (!ShapedType::isDynamic(rhsShape[dimIndPos])) - continue; - dimSize = b.create(loc, rhs, dimIndPos); - } - dynSizes.push_back(dimSize); - } - return dynSizes; -} - -// Adds indices/axes that are missing from output set. -llvm::SmallSetVector -findSummationAxes(const llvm::SmallSetVector &inputSet, - const llvm::SmallSetVector &outputSet) { - llvm::SmallSetVector summationAxes; - for (StringRef ind : inputSet) { - if (!outputSet.contains(ind)) - summationAxes.insert(ind); - } - return summationAxes; -} - -// Given a 1:1 map from std::string -> affine dimension expression -// we can get the affine expression of dimensions that an -// operand will access based on the input_str of einsum_config. -// For example: -// let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2} -// for einsum_config "abc,cb->acb" -// first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2). -// second_input_operand will get umap[{"c","b"}] -> (d2, d1). -// output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1). -SmallVector -getExprFromConfig(const SmallVector &loopDims, - const DenseMap &strAffineDimUmap) { - SmallVector exprs; - for (const auto &dim : loopDims) { - exprs.push_back(strAffineDimUmap.lookup(dim)); - } - return exprs; -} - -// Convert stablehlo.einsum op into linalg.generic. -// Algorithm in general 3 steps: - -// Step1) Dissect entire einsum_config to different operands -// e.g f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}. - -// Step2) Split up the string into vector of the elements -// e.g {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"], -// rhs:["c","d"], out:["a","b","d"]}. - -// Step3) Convert the vector into data access -// patern represented by affineMaps with affineDimensions e.g -// {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2], -// rhs:[d2,d3], out:[d0,d1,d3]}. -struct EinsumToLinalgConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::EinsumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto getRank = [](Value v) { - return llvm::cast(v.getType()).getRank(); - }; - auto einsumConfig = op.getEinsumConfig(); - - // With the assumption of binary input operand and single output - // get the inputs and output operands' indices. - // einsum_config = "lhs_loop,rhs_loop->out_loop" - std::size_t posArrow = einsumConfig.find(kArrow); - std::size_t posComma = einsumConfig.find(kComma); - - StringRef lhsLoop = einsumConfig.substr(0, posComma); - StringRef rhsLoop = einsumConfig.substr( - posComma + kComma.size(), posArrow - (posComma + kComma.size())); - StringRef outLoop = einsumConfig.substr(posArrow + kArrow.size()); - - // Check for Invalid Configs. - // 1.Check that there is only maximum 2 inputs - // 2.Check that there is only maximum 1 output - // 3.Check that there is 1 kArrow - if (rhsLoop.contains(kComma) || outLoop.contains(kComma) || - outLoop.contains(kArrow)) { - return rewriter.notifyMatchFailure(op, "Invalid einsum config!"); - } - - // Find result type, if on tensors. - auto resultTy = getTypeConverter()->convertType( - getHloOpResultType(op)); - - // Check result type compatibility. - if (!resultTy || !resultTy.getElementType().isSignlessIntOrFloat()) { - return rewriter.notifyMatchFailure(op, "Invalid result type"); - } - - // Convert the representation to vector. - SmallVector lhsEin = - getEinsumConfigAsVector(lhsLoop, getRank(adaptor.getLhs())); - SmallVector rhsEin = - getEinsumConfigAsVector(rhsLoop, getRank(adaptor.getRhs())); - SmallVector outEin = - getEinsumConfigAsVector(outLoop, resultTy.getRank()); - - if (!checkBatchHasEqualRank(lhsEin.size(), lhsLoop, rhsEin.size(), rhsLoop, - outEin.size(), outLoop)) { - return rewriter.notifyMatchFailure( - op, "Invalid elipsis('...') within einsum config!"); - } - - // Find all unique indices in the input and output. - llvm::SmallSetVector inputInd; - llvm::SmallSetVector outputInd; - - inputInd.insert(lhsEin.begin(), lhsEin.end()); - inputInd.insert(rhsEin.begin(), rhsEin.end()); - outputInd.insert(outEin.begin(), outEin.end()); - - llvm::SmallSetVector reductionAxe = - findSummationAxes(inputInd, outputInd); - - // Find input/output values and types. - Location loc = op.getLoc(); - - // Prepare init tensor for linalg.generic op. - auto dynSizes = - extractDynamicEinsumSizes(rewriter, loc, adaptor.getLhs(), - adaptor.getRhs(), lhsEin, rhsEin, outEin); - Value output = getEmptyTensor(rewriter, loc, resultTy, dynSizes); - if (!reductionAxe.empty()) { - output = fillTensorWithZeros(rewriter, loc, output); - } - - // Create indexing maps. - // Create a 1:1 map from f:strDimension -> affineDimension. - int64_t nloops = inputInd.size(); - DenseMap strAffineDimUmap; - for (auto [idx, value] : llvm::enumerate(inputInd)) { - strAffineDimUmap[value] = rewriter.getAffineDimExpr(idx); - } - - // From einsum_config of each operand in vector, generate - // the equivalent vector. - SmallVector maps; - for (const SmallVector &loopOperand : - {lhsEin, rhsEin, outEin}) { - auto exprs = getExprFromConfig(loopOperand, strAffineDimUmap); - maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext())); - } - - auto linalgOp = rewriter.create( - loc, resultTy ? resultTy : TypeRange{}, adaptor.getOperands(), output, - maps, getEinsumLoopsAttrs(inputInd, reductionAxe), - [reductionAxe](OpBuilder &b, Location nestedLoc, ValueRange args) { - Value resultVal = - b.create(nestedLoc, args[0], args[1]); - if (!reductionAxe.empty()) { - resultVal = - b.create(nestedLoc, args[2], resultVal); - } - b.create(nestedLoc, resultVal); - }, - linalg::getPrunedAttributeList(op)); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } - -private: - static constexpr StringLiteral kArrow = "->"; - static constexpr StringLiteral kComma = ","; - static constexpr StringLiteral kEllipsis = "..."; - - static bool checkBatchHasEqualRank(size_t lhsRank, StringRef lhsLoop, - size_t rhsRank, StringRef rhsLoop, - size_t outRank, StringRef outLoop); - static SmallVector getEinsumConfigAsVector(StringRef loop, - size_t operandRank); -}; - -// Convert the representation from string/vector to vector. -// i.e ("abc") -> {"a", "b", "c"}. For cases with ellipsis with batch rank 3: -// get loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"} -SmallVector -EinsumToLinalgConverter::getEinsumConfigAsVector(StringRef loop, - size_t operandRank) { - SmallVector loopDim; - size_t preElip = loop.find(kEllipsis); - bool hasElip = preElip != StringRef::npos; - if (!hasElip) - preElip = loop.size(); - // Add the dimension until the end or up to ellipsis if it exist. - for (int64_t preElipInd = 0; preElipInd < static_cast(preElip); - preElipInd++) { - loopDim.push_back(loop.substr(preElipInd, 1).str()); - } - if (!hasElip) - return loopDim; - // Case where Ellipsis presence: - size_t nonBatchRank = loop.size() - kEllipsis.size(); - size_t batchRank = operandRank - nonBatchRank; - // Add the batch dimension ("0",...,"N") where N is rank of batch into the - // loop. - for (int64_t batchInd = 0; batchInd < static_cast(batchRank); - batchInd++) { - loopDim.push_back(std::to_string(batchInd)); - } - // Add the dimension after ellipsis into the loop. - int postElip = preElip + kEllipsis.size(); - for (int64_t postElipInd = postElip; - postElipInd < static_cast(loop.size()); ++postElipInd) { - loopDim.push_back(loop.substr(postElipInd, 1).str()); - } - return loopDim; -} - -// Returns true if all operand's batch has same rank. -bool EinsumToLinalgConverter::checkBatchHasEqualRank( - size_t lhsRank, StringRef lhsLoop, size_t rhsRank, StringRef rhsLoop, - size_t outRank, StringRef outLoop) { - SmallVector batchRankVec; - if (lhsRank != lhsLoop.size()) { - size_t lhsBatchRank = lhsRank - (lhsLoop.size() - kEllipsis.size()); - batchRankVec.push_back(lhsBatchRank); - } - if (rhsRank != rhsLoop.size()) { - size_t rhsBatchRank = rhsRank - (rhsLoop.size() - kEllipsis.size()); - batchRankVec.push_back(rhsBatchRank); - } - if (outRank != outLoop.size()) { - size_t outBatchRank = outRank - (outLoop.size() - kEllipsis.size()); - batchRankVec.push_back(outBatchRank); - } - bool batchHasEqualRank = true; - - // Condition is valid if only 1 operand or less have batches. - if (batchRankVec.size() < 2) - return batchHasEqualRank; - - if (!llvm::all_equal(batchRankVec)) - return false; - - return batchHasEqualRank; -} - -/// Base class for lowering HLO operations that have one operand and one result, -/// and are semantically equivalent to a copy of the input to the output (like -/// transpose, some reshape, etc.). The derived classes need to provide a method -/// `getIndexingMaps` that returns AffineMaps for the index maps of the input -/// and the output. -template -struct DataMovementOpConverter : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - ShapedType resultType = getHloOpResultType(op); - resultType = - this->getTypeConverter()->template convertType(resultType); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - - SmallVector indexingMaps = - Derived::getIndexingMaps(op, &rewriter); - if (indexingMaps.empty()) - return failure(); - - int64_t nloops = resultType.getRank(); - Location loc = op.getLoc(); - auto linalgOp = rewriter.create( - loc, - /*resultTensorTypes=*/resultType, - /*inputs=*/adaptor.getOperands().front(), - /*outputBuffers=*/ - - ValueRange{getEmptyTensorFor(rewriter, loc, resultType, op, - adaptor.getOperands())}, - indexingMaps, getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - nestedBuilder.create(loc, *args.begin()); - }, - linalg::getPrunedAttributeList(op)); - rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); - return success(); - } -}; - -/// Pattern to convert BroadcastOp to Linalg ops. -template -struct BroadcastConverter final - : DataMovementOpConverter, OpTy> { - using DataMovementOpConverter::DataMovementOpConverter; - - static SmallVector getIndexingMaps(OpTy broadcastOp, - Builder *b) { - ShapedType inputType = - llvm::cast(broadcastOp.getOperand().getType()); - unsigned inputRank = inputType.getRank(); - unsigned nloops = getHloOpResultType(broadcastOp).getRank(); - - // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to - // the input's dimensions. - unsigned numPrependedDims = llvm::size(broadcastOp.getBroadcastSizes()); - SmallVector inputDimExprs; - inputDimExprs.reserve(inputRank); - for (unsigned i = 0; i < inputRank; ++i) { - inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); - } - - AffineMap inputMap; - MLIRContext *context = b->getContext(); - if (inputDimExprs.empty()) { - // The input is a scalar, i.e. this is a scalar broadcast op. - inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); - } else { - inputMap = - AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); - } - return {inputMap, b->getMultiDimIdentityMap(nloops)}; - } -}; - -struct BroadcastOpToBroadcastConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::BroadcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - int64_t numPrependedDims = op.getBroadcastSizes().size(); - SmallVector dimensions = - llvm::to_vector(llvm::seq(0, numPrependedDims)); - - Location loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - rewriter.replaceOpWithNewOp( - op, op.getOperand(), emptyTensor, dimensions, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct HloBroadcastInDimConverter final - : DataMovementOpConverter { - using DataMovementOpConverter::DataMovementOpConverter; - - static SmallVector - getIndexingMaps(mlir::stablehlo::BroadcastInDimOp broadcastOp, Builder *b) { - ShapedType resultType = getHloOpResultType(broadcastOp); - auto operandType = cast(broadcastOp.getOperand().getType()); - unsigned nloops = resultType.getRank(); - - // The input is a scalar, i.e. this is a scalar broadcast op. - if (operandType.getRank() == 0) { - return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } - - ArrayRef operandShape = operandType.getShape(); - SmallVector dimExprs; - dimExprs.reserve(nloops); - - for (auto [idx, size] : - llvm::enumerate(broadcastOp.getBroadcastDimensions())) { - bool expansionNeeded = - operandShape[idx] == 1 && resultType.getShape()[size] != 1; - dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0) - : b->getAffineDimExpr(size)); - } - return { - AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } -}; - -Value collapseExpandingDims(PatternRewriter &rewriter, Location loc, - Value operand, SmallVector &dimensions, - llvm::function_ref isExpandingDim) { - auto operandTy = llvm::cast(operand.getType()); - - SmallVector reassociationMap; - ReassociationIndices currentIndices; - - ArrayRef operandShape = operandTy.getShape(); - SmallVector newOperandShape; - SmallVector newDimensions; - - for (auto [idx, dim] : llvm::enumerate(dimensions)) { - currentIndices.push_back(idx); - - if (!isExpandingDim(idx)) { - reassociationMap.push_back(currentIndices); - currentIndices.clear(); - newOperandShape.push_back(operandShape[idx]); - newDimensions.push_back(dim); - } - } - - if (!reassociationMap.empty()) { - reassociationMap.back().insert(reassociationMap.back().end(), - currentIndices.begin(), - currentIndices.end()); - } - - if (dimensions.size() != newDimensions.size()) { - dimensions = newDimensions; - - auto newOperandType = - RankedTensorType::get(newOperandShape, operandTy.getElementType()); - operand = rewriter.create( - loc, newOperandType, operand, reassociationMap); - } - return operand; -} - -// Insert linalg.transpose if broadcasted dimensions are not in sorted order. -// linalg.broadcast does not support implicit transpose, so the input needs to -// be explicitly transposed. -Value transposeBroadcastOperand(PatternRewriter &rewriter, Location loc, - Value operand, - SmallVector &dimensions) { - // Do not insert `transpose` is dimensions are already sorted. - if (llvm::is_sorted(dimensions)) - return operand; - - SmallVector permutation = - llvm::to_vector(llvm::seq(0, dimensions.size())); - llvm::sort(permutation, [&](int64_t lhs, int64_t rhs) { - return dimensions[lhs] < dimensions[rhs]; - }); - - auto operandTy = llvm::cast(operand.getType()); - ArrayRef operandShape = operandTy.getShape(); - SmallVector transposedOperandShape, transposedDimensions; - - for (int64_t index : permutation) { - transposedOperandShape.push_back(operandShape[index]); - transposedDimensions.push_back(dimensions[index]); - } - dimensions = transposedDimensions; - - return rewriter.create( - loc, - RankedTensorType::get(transposedOperandShape, operandTy.getElementType()), - operand, rewriter.getDenseI64ArrayAttr(permutation)); -} - -struct BroadcastInDimOpToBroadcastConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions()); - - Value operand = adaptor.getOperand(); - auto operandTy = llvm::cast(operand.getType()); - auto resultTy = - llvm::cast(typeConverter->convertType(op.getType())); - - ArrayRef operandShape = operandTy.getShape(); - ArrayRef resultShape = resultTy.getShape(); - - operand = collapseExpandingDims( - rewriter, loc, operand, broadcastDimensions, [&](int64_t i) { - return operandShape[i] == 1 && - resultShape[broadcastDimensions[i]] != 1; - }); - operand = - transposeBroadcastOperand(rewriter, loc, operand, broadcastDimensions); - - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - SmallVector addedDimensions; - for (int64_t dim : llvm::seq(0, resultTy.getRank())) { - if (!llvm::is_contained(broadcastDimensions, dim)) - addedDimensions.push_back(dim); - } - - rewriter.replaceOpWithNewOp( - op, operand, emptyTensor, addedDimensions, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -// If the input has a static shape we know exactly when the broadcast must -// expand (the dimension is 1, which also trivially expands to 1) or will never -// expand (the dimension is not 1). We can also source the information from the -// optionally provided attributes on statically known broadcasting behavior. -// This means we can lower the broadcast just as we would lower a fully static -// broadcast and go directly to `linalg.generic`. - -// This also covers the important case of broadcasting a scalar. Ideally the -// pattern (`stablehlo.constant` -> `stablehlo.dynamic_broadcast_in_dim`) should -// be converted to a tensor dialect op similar to TF's `ConstantLikeOp`. -struct HloDynamicBroadcastInDimConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value operand = adaptor.getOperand(); - auto operandType = dyn_cast(operand.getType()); - if (!operandType) - return failure(); - auto resultType = - getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - // Determine dimension expressions based on whether the dimension is - // expanding (0) or non-expanding (identity), and fail if we cannot decide - // this. - SmallVector dimExprs(operandType.getRank(), nullptr); - - // Use static type info. - auto bcastDims = - llvm::map_to_vector(op.getBroadcastDimensions(), - [](int64_t d) { return static_cast(d); }); - for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) { - if (ShapedType::isDynamic(dim)) - continue; - - bool isExpanding = dim == 1; - dimExprs[idx] = isExpanding ? rewriter.getAffineConstantExpr(0) - : rewriter.getAffineDimExpr(bcastDims[idx]); - } - - // Use annotated expansion behavior, if available. - if (auto dims = op.getKnownExpandingDimensions()) { - for (int i : *dims) { - dimExprs[i] = rewriter.getAffineConstantExpr(0); - } - } - if (auto dims = op.getKnownNonexpandingDimensions()) { - for (int i : *dims) { - dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]); - } - } - - // Fail if unknown expansion behavior remains. - if (!llvm::all_of(dimExprs, [](AffineExpr expr) { return expr; })) - return failure(); - - // Materialize `linalg.generic` op. - Location loc = op.getLoc(); - int64_t nloops = resultType.getRank(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - rewriter.replaceOpWithNewOp( - op, TypeRange{emptyTensor.getType()}, ValueRange{operand}, - /*outputBuffers=*/ValueRange{emptyTensor}, - llvm::ArrayRef({AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, - dimExprs, rewriter.getContext()), - rewriter.getMultiDimIdentityMap(nloops)}), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - nestedBuilder.create(loc, *args.begin()); - }, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct DynamicBroadcastInDimOpToBroadcastConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - Value operand = adaptor.getOperand(); - auto operandTy = llvm::dyn_cast(operand.getType()); - if (!operandTy) - return failure(); - auto resultTy = - getTypeConverter()->convertType(op.getType()); - if (!resultTy) - return failure(); - - SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions()); - - SmallVector> expansionBehavior( - broadcastDimensions.size()); - - // Use static type info. - for (auto [idx, dim] : llvm::enumerate(operandTy.getShape())) { - if (ShapedType::isDynamic(dim)) - continue; - expansionBehavior[idx] = (dim == 1); - } - - // Use annotated expansion behavior, if available. - if (op.getKnownExpandingDimensions()) { - auto dims = op.getKnownExpandingDimensions().value(); - for (int it : dims) { - expansionBehavior[it] = true; - } - } - if (op.getKnownNonexpandingDimensions()) { - auto dims = op.getKnownNonexpandingDimensions().value(); - for (int it : dims) { - expansionBehavior[it] = false; - } - } - - // Fail if unknown expansion behavior remains. - if (llvm::any_of(expansionBehavior, [](auto v) { return !v.has_value(); })) - return failure(); - - auto isExpandingDim = [&](int64_t i) { - return expansionBehavior[i].value(); - }; - - // Use attribute information to insert 1s into operand type. - operand = getBroadcastOperand(rewriter, loc, operand, isExpandingDim); - - auto broadcastResultTy = getBroadcastResultType( - operand, resultTy, broadcastDimensions, isExpandingDim); - - operand = collapseExpandingDims(rewriter, loc, operand, broadcastDimensions, - isExpandingDim); - operand = - transposeBroadcastOperand(rewriter, loc, operand, broadcastDimensions); - - Value emptyTensor = getEmptyTensorFor(rewriter, loc, broadcastResultTy, op, - adaptor.getOperands()); - - SmallVector addedDimensions; - for (int64_t dim : llvm::seq(0, resultTy.getRank())) { - if (!llvm::is_contained(broadcastDimensions, dim)) - addedDimensions.push_back(dim); - } - - Value result = rewriter - .create( - loc, operand, emptyTensor, addedDimensions, - linalg::getPrunedAttributeList(op)) - .getResults()[0]; - - if (resultTy != broadcastResultTy) { - result = rewriter.create(loc, resultTy, result); - } - - rewriter.replaceOp(op, result); - return success(); - } - -private: - static Value - getBroadcastOperand(PatternRewriter &rewriter, Location loc, Value operand, - llvm::function_ref isExpandingDim) { - auto operandTy = llvm::dyn_cast(operand.getType()); - - SmallVector updatedOperandShape = - llvm::to_vector(operandTy.getShape()); - for (auto [idx, dim] : llvm::enumerate(updatedOperandShape)) { - if (isExpandingDim(idx)) - dim = 1; - } - - auto updatedOperandTy = - RankedTensorType::get(updatedOperandShape, operandTy.getElementType()); - - if (updatedOperandTy != operandTy) { - operand = rewriter.create(loc, updatedOperandTy, operand); - } - - return operand; - } - - static ShapedType - getBroadcastResultType(Value operand, RankedTensorType resultTy, - ArrayRef dimensions, - llvm::function_ref isExpandingDim) { - auto operandShape = - llvm::cast(operand.getType()).getShape(); - auto broadcastResultShape = llvm::to_vector(resultTy.getShape()); - - for (auto [operandIndex, resultIndex] : llvm::enumerate(dimensions)) { - if (isExpandingDim(operandIndex)) - continue; - broadcastResultShape[resultIndex] = operandShape[operandIndex]; - } - - return RankedTensorType::get(broadcastResultShape, - resultTy.getElementType()); - } -}; - -template -struct TransposeConverter final - : DataMovementOpConverter, OpTy> { - using DataMovementOpConverter, - OpTy>::DataMovementOpConverter; - - static SmallVector getIndexingMaps(OpTy op, Builder *b) { - auto resultType = llvm::cast(getHloOpResultType(op)); - int64_t nloops = resultType.getRank(); - SmallVector inputExprs; - inputExprs.resize(resultType.getRank()); - for (auto [idx, value] : llvm::enumerate(op.getPermutation())) { - inputExprs[value] = b->getAffineDimExpr(idx); - } - return { - AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } -}; - -struct TransposeOpToTransposeConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::TransposeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - Location loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - auto permutation = - dyn_cast_or_null(op.getPermutationAttr()); - - rewriter.replaceOpWithNewOp( - op, adaptor.getOperand(), emptyTensor, permutation, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct BitcastConvertConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::BitcastConvertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - auto inputType = - llvm::cast(adaptor.getOperand().getType()); - auto outputType = - getTypeConverter()->convertType(op.getType()); - if (!outputType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - Location loc = op.getLoc(); - - // Fallback to pointwise conversion if the tensor dimensions are not - // changing. - if (inputType.getRank() == outputType.getRank()) { - return failure(); - } - - auto inputBitWidth = inputType.getElementType().getIntOrFloatBitWidth(); - auto outputBitWidth = outputType.getElementType().getIntOrFloatBitWidth(); - - auto maxRank = std::max(inputType.getRank(), outputType.getRank()); - auto identityMap = - AffineMap::getMultiDimIdentityMap(maxRank, rewriter.getContext()); - AffineMap indexingMaps[] = { - AffineMap::get( - /*dimCount=*/maxRank, /*symbolCount=*/0, - identityMap.getResults().take_front(inputType.getRank()), - rewriter.getContext()), - AffineMap::get( - /*dimCount=*/maxRank, /*symbolCount=*/0, - identityMap.getResults().take_front(outputType.getRank()), - rewriter.getContext())}; - - Value output = - getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); - bool isExpansion = inputBitWidth > outputBitWidth; - bool isContraction = inputBitWidth < outputBitWidth; - // When combining values we start with a 0 and merge bits into it. - if (isContraction) { - output = fillTensorWithZeros(rewriter, loc, output); - } - - rewriter.replaceOpWithNewOp( - op, outputType, adaptor.getOperand(), output, indexingMaps, - getParallelAndReductionIterators(maxRank, isContraction ? 1 : 0), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - auto inIntType = nestedBuilder.getIntegerType(inputBitWidth); - auto outIntType = nestedBuilder.getIntegerType(outputBitWidth); - Value innerResult = args.front(); - if (isExpansion) { - // Expand a big value into multiple small values with shifts. - auto iotaIndex = - nestedBuilder.create(nestedLoc, maxRank - 1); - auto iota = nestedBuilder.create( - nestedLoc, inIntType, iotaIndex); - - auto width = nestedBuilder.create( - nestedLoc, - nestedBuilder.getIntegerAttr(inIntType, outputBitWidth)); - auto shiftWidth = - nestedBuilder.create(nestedLoc, iota, width); - Value inputCasted = nestedBuilder.create( - nestedLoc, inIntType, args.front()); - Value shifted = nestedBuilder.create( - nestedLoc, inputCasted, shiftWidth); - innerResult = nestedBuilder.create( - nestedLoc, outIntType, shifted); - } else if (isContraction) { - // Combine multiple small values into one big value. - auto iotaIndex = - nestedBuilder.create(nestedLoc, maxRank - 1); - auto iota = nestedBuilder.create( - nestedLoc, outIntType, iotaIndex); - - auto width = nestedBuilder.create( - nestedLoc, - nestedBuilder.getIntegerAttr(outIntType, inputBitWidth)); - auto shiftWidth = - nestedBuilder.create(nestedLoc, iota, width); - Value inputCasted = nestedBuilder.create( - nestedLoc, inIntType, args.front()); - Value inputExt = nestedBuilder.create( - nestedLoc, outIntType, inputCasted); - Value shifted = nestedBuilder.create( - nestedLoc, inputExt, shiftWidth); - Value accumulatorCasted = nestedBuilder.create( - nestedLoc, outIntType, args.back()); - innerResult = nestedBuilder.create( - nestedLoc, outIntType, shifted, accumulatorCasted); - } - innerResult = nestedBuilder.create( - nestedLoc, outputType.getElementType(), innerResult); - nestedBuilder.create(nestedLoc, innerResult); - }, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -// Lowers stablehlo.RealDynamicSliceOp to tensor.extract_slice and other -// arith/tensor dialect ops. -struct RealDynamicSliceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - // Computes size of a slice as - // size = ceil((limit - start)/stride) - static Value computeSize(Location loc, Value start, Value limit, Value stride, - ConversionPatternRewriter &b) { - Value delta = b.create(loc, limit, start); - Value ret = b.create(loc, delta, stride); - if (ret.getType().isIndex()) - return ret; - return b.create(loc, b.getIndexType(), ret); - } - - LogicalResult - matchAndRewrite(mlir::stablehlo::RealDynamicSliceOp realDynamicSliceOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = realDynamicSliceOp.getLoc(); - auto argType = llvm::dyn_cast(adaptor.getOperand().getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(realDynamicSliceOp, - "require known-rank args"); - } - - Type dimElementType = getElementTypeOrSelf(adaptor.getStartIndices()); - if (getElementTypeOrSelf(adaptor.getLimitIndices()) != dimElementType || - getElementTypeOrSelf(adaptor.getStrides()) != dimElementType) { - return rewriter.notifyMatchFailure( - realDynamicSliceOp, - "requires same element type for all dimension specification"); - } - Type arithType = - dimElementType.isIndex() ? rewriter.getI64Type() : dimElementType; - Type indexType = rewriter.getIndexType(); - - auto resultType = llvm::cast( - this->typeConverter->convertType(realDynamicSliceOp.getType())); - Value zero = rewriter.create(loc, 0); - SmallVector offsets, sizes, strides; - SmallVector clampType(3, arithType); - for (auto i : llvm::seq(0, argType.getRank())) { - Value dim = rewriter.create(loc, i); - Value start = rewriter.create( - loc, adaptor.getStartIndices(), dim); - Value limit = rewriter.create( - loc, adaptor.getLimitIndices(), dim); - Value stride = - rewriter.create(loc, adaptor.getStrides(), dim); - - // Compute i-th dimension size of the result : size[i]. - // If the i-th dimension of the result type is known, we go ahead with it - // else we compute it using limit, start and stride values. - int64_t resultDimSize = resultType.getDimSize(i); - Value size = - ShapedType::isDynamic(resultDimSize) - ? computeSize(loc, start, limit, stride, rewriter) - : rewriter.create(loc, resultDimSize); - - // We can now convert start to index. - if (!start.getType().isIndex()) - start = rewriter.create( - loc, rewriter.getIndexType(), start); - - // Fetch i-th dimension size of the operand and calculate upper bound as - // ub = operand_dim[i] - size[i] - Value operandDimSize = - rewriter.createOrFold(loc, adaptor.getOperand(), dim); - Value upperBound = - rewriter.createOrFold(loc, operandDimSize, size); - - // We clamp the start_index to keep it bounded as - // 0 <= start_index[i] <= ub - // Clamp does not support index type, so cast to integer type. - start = rewriter.create(loc, start, zero); - start = rewriter.create(loc, start, upperBound); - - offsets.push_back(start); - if (ShapedType::isDynamic(resultDimSize)) - sizes.push_back(size); - else - sizes.push_back(IntegerAttr::get(indexType, resultDimSize)); - - if (!stride.getType().isIndex()) - stride = - rewriter.createOrFold(loc, indexType, stride); - strides.push_back(stride); - } - - rewriter.replaceOpWithNewOp( - realDynamicSliceOp, resultType, adaptor.getOperand(), offsets, sizes, - strides); - return success(); - } -}; - -// Converts reshape ops that can be proven to be either a collapse of dimensions -// or expansion of dimensions of the operand. -struct ReshapeOpConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReshapeOp reshapeOp, - mlir::stablehlo::ReshapeOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(reshapeOp))) - return failure(); - Value operand = adaptor.getOperand(); - auto operandType = llvm::cast(operand.getType()); - Type elemType = operandType.getElementType(); - auto resultType = llvm::cast(reshapeOp.getType()); - - if (!resultType.hasStaticShape()) - return failure(); - - // If any of the output dimensions is 0, the tensor has no elements. In that - // case, we can just replace the reshape with an empty op. - if (llvm::is_contained(resultType.getShape(), 0)) { - rewriter.replaceOpWithNewOp( - reshapeOp, resultType.getShape(), elemType); - return success(); - } - - resultType = getTypeConverter()->convertType(resultType); - if (!resultType) - return rewriter.notifyMatchFailure(reshapeOp, "type conversion failed"); - - // Special case where the result is a scalar. - if (resultType.getRank() == 0 && !operandType.hasStaticShape()) { - // This means all dimensions of the operand need to be 1. We add a cast to - // cast the dynamic dimensions to 1. - auto staticType = RankedTensorType::get( - llvm::SmallVector(operandType.getRank(), 1), elemType); - operand = rewriter.create(reshapeOp.getLoc(), staticType, - operand); - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, operand, ArrayRef{}); - return success(); - } - - // Compute the reassociation maps for the linalg operation. This will - // succeed if the reshape can be done with a single expand_shape or - // collapse_shape. - if (std::optional> reassociationMap = - getReassociationIndicesForReshape(operandType, resultType)) { - if (resultType.getRank() < operandType.getRank()) { - // We have found a working reassociation map. If the operand is dynamic, - // we first need to cast all unknown dimensions in the input that get - // collapsed to a static-sized dimension in the output, to 1. - SmallVector shape(operandType.getShape().begin(), - operandType.getShape().end()); - for (auto [idx, dims] : llvm::enumerate(*reassociationMap)) { - // If the result dim is dynamic, we do not mind dynamic entries in the - // source. - if (resultType.isDynamicDim(idx)) - continue; - for (auto targetDim : dims) { - if (ShapedType::isDynamic(shape[targetDim])) - shape[targetDim] = 1; - } - } - // Insert a cast if types are not the same (ignoring sparse encoding). - auto enc = sparse_tensor::getSparseTensorEncoding(operandType); - auto newOperandType = RankedTensorType::get(shape, elemType, enc); - if (newOperandType != operandType) { - operand = rewriter.create(reshapeOp.getLoc(), - newOperandType, operand); - } - // Generate collapse operation. - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, operand, *reassociationMap); - } else { - // Generate expand operation. - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, operand, *reassociationMap); - } - return success(); - } - - Value collapsedOp = operand; - Location loc = reshapeOp.getLoc(); - auto getIdentityExprs = [&rewriter](int64_t n) { - SmallVector exprs; - for (int i = 0; i < n; ++i) - exprs.push_back(rewriter.getAffineDimExpr(i)); - return exprs; - }; - // Otherwise, we need to first reduce all source dimensions into one and - // then expand to the destination dimensions. If there is only a single - // source dimension, the reduce step can be skipped. TensorCollapseShape - // expects a different rank of operand and result. - if (operandType.getRank() != 1) { - SmallVector collapsingMap = { - // Use operand_type here because we need to collapse all operands - // dimensions. - getIdentityExprs(operandType.getRank())}; - - collapsedOp = - rewriter.create(loc, operand, collapsingMap); - } - // Cast to a known static type if the input has dynamic dimensions. - int64_t totalElems = resultType.getNumElements(); - auto collapsedType = RankedTensorType::get({totalElems}, elemType); - collapsedOp = - rewriter.create(loc, collapsedType, collapsedOp); - if (resultType.getRank() == 1) { - rewriter.replaceOp(reshapeOp, collapsedOp); - } else { - SmallVector expandingMap = { - // Use resultType here because we need to expand to all result - // dimensions. - getIdentityExprs(resultType.getRank())}; - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, collapsedOp, expandingMap); - } - return success(); - } -}; - -template -struct IotaConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using Adaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy iotaOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ShapedType resultShapedType = getHloOpResultType(iotaOp); - if (!resultShapedType) - return failure(); - - resultShapedType = - this->getTypeConverter()->template convertType( - resultShapedType); - if (!resultShapedType) - return rewriter.notifyMatchFailure(iotaOp, "type conversion failed"); - - Type resultElementType = resultShapedType.getElementType(); - - // Construct the indexing maps needed for linalg.generic ops. - unsigned nloops = resultShapedType.getRank(); - - Location loc = iotaOp.getLoc(); - auto linalgOp = rewriter.create( - loc, - /*resultTensorTypes=*/ - ArrayRef{resultShapedType}, - /*inputs=*/ValueRange{}, - /*outputBuffers=*/ - - ValueRange{getEmptyTensorFor(rewriter, loc, resultShapedType, iotaOp, - adaptor.getOperands())}, - llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange /*args*/) { - Value indexOp = nestedBuilder.create( - nestedLoc, iotaOp.getIotaDimension()); - Type unwrappedResultElementType = resultElementType; - if (auto complexType = - llvm::dyn_cast(unwrappedResultElementType)) - unwrappedResultElementType = complexType.getElementType(); - Value castOp = nestedBuilder.create( - nestedLoc, - nestedBuilder.getIntegerType( - unwrappedResultElementType.getIntOrFloatBitWidth()), - indexOp); - castOp = mlir::stablehlo::StableHloOpToStdScalarOp::mapOpOfType< - mlir::stablehlo::ConvertOp>(nestedLoc, resultElementType, - castOp.getType(), {castOp}, - &nestedBuilder); - nestedBuilder.create(nestedLoc, castOp); - }, - linalg::getPrunedAttributeList(iotaOp)); - rewriter.replaceOp(iotaOp, linalgOp.getResultTensors()); - return success(); - } -}; - -template -struct IotaToMapConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using Adaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy iotaOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ShapedType resultTy = getHloOpResultType(iotaOp); - if (!resultTy) - return failure(); - - resultTy = - this->getTypeConverter()->template convertType(resultTy); - if (!resultTy) - return rewriter.notifyMatchFailure(iotaOp, "type conversion failed"); - - Location loc = iotaOp.getLoc(); - Value empty = getEmptyTensorFor(rewriter, loc, resultTy, iotaOp, - adaptor.getOperands()); - - auto linalgOp = rewriter.create( - loc, ValueRange{}, empty, - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange /*args*/) { - Value index = nestedBuilder.create( - nestedLoc, iotaOp.getIotaDimension()); - index = nestedBuilder.create( - nestedLoc, nestedBuilder.getI64Type(), index); - Value result = mlir::stablehlo::StableHloOpToStdScalarOp::mapOpOfType< - mlir::stablehlo::ConvertOp>(nestedLoc, resultTy.getElementType(), - index.getType(), {ValueRange{index}}, - &nestedBuilder); - nestedBuilder.create(nestedLoc, ValueRange{result}); - }, - linalg::getPrunedAttributeList(iotaOp)); - rewriter.replaceOp(iotaOp, linalgOp.getResult()); - return success(); - } -}; - -/// Converts stablehlo.concatenate operation to a linalg.generic op. -struct ConcatenateConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConcatenateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Shortcut the one-operand case, simplifies code below. - if (adaptor.getOperands().size() == 1) { - rewriter.replaceOp(op, adaptor.getOperands()[0]); - return success(); - } - - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - uint64_t dim = op.getDimension(); - Location loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - - // Allocate the output tensor with tensor.empty. - Value result = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - - // Generate a generic op to gather the elements of the concatenate. This is - // awkward standalone but allows fusion with other generic ops. - int64_t nloops = resultType.getRank(); - rewriter.replaceOpWithNewOp( - op, - /*resultTensorTypes=*/resultType, - /*inputs=*/ValueRange{}, /*outputBuffers=*/result, - llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location loc, ValueRange) { - OpBuilder b = nestedBuilder; - Value concatDimSize = zero; - Value result; - - SmallVector extractIndices; - extractIndices.reserve(nloops); - for (int64_t i = 0; i < nloops; i++) { - extractIndices.push_back(b.create(loc, i)); - } - - Value indexOp = b.create(loc, dim); - for (auto [idx, arg] : llvm::enumerate(adaptor.getOperands())) { - Value newConcatDimSize; - scf::IfOp ifOp; - if (idx + 1 != adaptor.getOperands().size()) { - // Calculate how far along we have iterated along the concatenate - // dimension. That way we can tell which input to select. - newConcatDimSize = b.create( - loc, concatDimSize, b.create(loc, arg, dim)); - Value cmp = b.create(loc, rewriter.getI1Type(), - arith::CmpIPredicate::ult, - indexOp, newConcatDimSize); - ifOp = b.create(loc, resultType.getElementType(), cmp, - true); - if (result) { - b.create(loc, ifOp->getResults()[0]); - } else { - result = ifOp->getResults()[0]; - } - - b = ifOp.getThenBodyBuilder(b.getListener()); - } - - // Now adjust the index for the concatenated dimension to fit into - // the selected tensor and do an extract at that position. - extractIndices[dim] = - b.create(loc, indexOp, concatDimSize); - Value extract = - b.create(loc, arg, extractIndices); - b.create(loc, extract); - - if (ifOp) { - b = ifOp.getElseBodyBuilder(b.getListener()); - concatDimSize = newConcatDimSize; - } - } - nestedBuilder.create(loc, result); - }, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct ConstConverterTensor final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConstantOp constOp, OpAdaptor /*adaptor*/, - ConversionPatternRewriter &rewriter) const override { - auto replacementType = - getTypeConverter()->convertType(constOp.getType()); - if (!replacementType) - return rewriter.notifyMatchFailure(constOp, "type conversion failed"); - - ElementsAttr replacementAttr = constOp.getValue(); - if (replacementType == constOp.getType()) { - rewriter.replaceOpWithNewOp(constOp, replacementType, - replacementAttr); - return success(); - } else { - auto denseAttr = dyn_cast(constOp.getValue()); - if (!denseAttr) { - return rewriter.notifyMatchFailure( - constOp, - "DenseElementsAttr cast failed (only DenseElementsAttr supported)"); - } - // Signedness conversion. - // TODO(#15442): Add generic mapping utility, so we aren't limited to - // supporting only DenseElementsAttr. - replacementAttr = denseAttr.mapValues(replacementType.getElementType(), - [](const APInt &i) { return i; }); - rewriter.replaceOpWithNewOp(constOp, replacementType, - replacementAttr); - return success(); - } - } -}; - -// TODO(b/156787842): Support the lowering for dynamic shapes. -struct ReverseConverter final - : DataMovementOpConverter { - using DataMovementOpConverter::DataMovementOpConverter; - - static SmallVector - getIndexingMaps(mlir::stablehlo::ReverseOp op, Builder *b) { - auto resultType = llvm::cast(getHloOpResultType(op)); - int64_t nloops = resultType.getRank(); - SmallVector inputExprs; - inputExprs.reserve(nloops); - for (int64_t i = 0; i < nloops; ++i) - inputExprs.push_back(b->getAffineDimExpr(i)); - for (int i : op.getDimensions()) { - if (resultType.isDynamicDim(i)) - return {}; - int n = resultType.getShape()[i]; - inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; - } - return { - AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } -}; - -struct SliceConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::SliceOp sliceOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto argType = - llvm::dyn_cast(adaptor.getOperands()[0].getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args"); - } - - SmallVector offsets, sizes, strides; - auto startIndices = sliceOp.getStartIndices(); - auto limitIndices = sliceOp.getLimitIndices(); - auto sliceStrides = sliceOp.getStrides(); - - for (int64_t i = 0, e = argType.getRank(); i < e; ++i) { - int64_t start = startIndices[i]; - int64_t limit = limitIndices[i]; - int64_t stride = sliceStrides[i]; - offsets.push_back(rewriter.getI64IntegerAttr(start)); - // Say that there are k elements in total, we have condition: - // start + (k - 1) * strides <= limit - 1 - // -> - // k <= (limit - 1 - start + strides) / strides - sizes.push_back( - rewriter.getI64IntegerAttr((limit - 1 - start + stride) / stride)); - strides.push_back(rewriter.getI64IntegerAttr(stride)); - } - rewriter.replaceOpWithNewOp( - sliceOp, adaptor.getOperands()[0], offsets, sizes, strides); - return success(); - } -}; - -struct DynamicSliceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicSliceOp dynamicSliceOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = dynamicSliceOp.getLoc(); - auto argType = llvm::dyn_cast(adaptor.getOperand().getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(dynamicSliceOp, - "require known-rank args"); - } - - auto resultType = getTypeConverter()->convertType( - dynamicSliceOp.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(dynamicSliceOp, - "type conversion failed"); - - SmallVector startIndices, sizes; - auto originalStartIndexType = llvm::cast( - dynamicSliceOp.getStartIndices().front().getType()); - for (auto [idx, start, size] : llvm::enumerate( - adaptor.getStartIndices(), dynamicSliceOp.getSliceSizes())) { - sizes.push_back(rewriter.getI64IntegerAttr(size)); - - // By stablehlo.DynamicSlice definition: - // `start_indices[i] = clamp(start_indices[i], - // 0, operand.dimension_size[i] - size_indices[i])` - Value startIndex = - extractIndexFromTensor(rewriter, loc, start, originalStartIndexType); - - Value mn = rewriter.create(loc, 0); - - Value mx = - rewriter.createOrFold(loc, adaptor.getOperand(), idx); - mx = rewriter.createOrFold( - loc, mx, rewriter.create(loc, size)); - - startIndex = rewriter.create(loc, startIndex, mn); - startIndex = rewriter.create(loc, startIndex, mx); - - startIndices.push_back(startIndex); - } - - int64_t rank = argType.getRank(); - SmallVector strides(rank, rewriter.getI64IntegerAttr(1)); - - rewriter.replaceOpWithNewOp( - dynamicSliceOp, resultType, adaptor.getOperand(), startIndices, sizes, - strides); - return success(); - } -}; - -struct DynamicUpdateSliceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto operandType = - llvm::dyn_cast(adaptor.getOperand().getType()); - if (!operandType || !operandType.hasStaticShape()) { - return rewriter.notifyMatchFailure( - op, "require static ranked type for operand"); - } - - auto updateType = - llvm::dyn_cast(adaptor.getUpdate().getType()); - if (!updateType || !updateType.hasStaticShape()) { - return rewriter.notifyMatchFailure( - op, "require static ranked type for operand"); - } - - // We do not have to clamp sizes because the semantic of `update` - // guarantees that it is always in the bounds. See - // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice - SmallVector sizes; - for (int64_t size : updateType.getShape()) { - sizes.push_back(rewriter.getIndexAttr(size)); - } - - SmallVector startIndices; - Value zero = rewriter.create(loc, 0); - for (auto [idx, start] : llvm::enumerate(adaptor.getStartIndices())) { - // By stablehlo.DynamicUpdateSlice definition: - // `start_indices[i] = clamp(start_indices[i], - // 0, operand.dimension_size[i] - update.dimension_size[i])` - Value startIndex = extractIndexFromTensor( - rewriter, loc, start, - cast(op.getStartIndices()[idx].getType())); - Value ub = rewriter.create( - loc, operandType.getDimSize(idx) - updateType.getDimSize(idx)); - - startIndex = rewriter.create(loc, startIndex, zero); - startIndex = rewriter.create(loc, startIndex, ub); - startIndices.push_back(startIndex); - } - - int64_t rank = operandType.getRank(); - SmallVector strides(rank, rewriter.getI64IntegerAttr(1)); - rewriter.replaceOpWithNewOp( - op, adaptor.getUpdate(), adaptor.getOperand(), startIndices, sizes, - strides); - return success(); - } -}; - -struct MapOpToGenericConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::MapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - assert(op.getDimensions().size() == resultType.getRank() && - "Expected a pointwise map"); - - Location loc = op.getLoc(); - Value output = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - SmallVector indexingMaps( - op.getNumOperands() + 1, - rewriter.getMultiDimIdentityMap(resultType.getRank())); - - auto linalgOp = rewriter.create( - loc, resultType, adaptor.getOperands(), output, indexingMaps, - getNParallelLoopsAttrs(resultType.getRank()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. We scalarize the operands and add a - // scalar operand representing the output tensor. - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getComputation(), region, region.end()); - TypeConverter::SignatureConversion signatureConverter(op.getNumOperands() + - 1); - - for (auto [idx, operand] : llvm::enumerate(op.getOperands())) { - Type convertedTy = getTypeConverter()->convertType( - cast(operand.getType()).getElementType()); - if (!convertedTy) - return rewriter.notifyMatchFailure(op, - "operand type conversion failed"); - - signatureConverter.addInputs(idx, convertedTy); - } - signatureConverter.addInputs(resultType.getElementType()); - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct MapOpToMapConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::MapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - assert(op.getDimensions().size() == resultType.getRank() && - "Expected a pointwise map"); - - Location loc = op.getLoc(); - Value operand0 = adaptor.getOperands()[0]; - SmallVector coercedOperands = {operand0}; - for (Value operand : llvm::drop_begin(adaptor.getOperands(), 1)) { - coercedOperands.push_back(coerceTensorShape( - rewriter, loc, cast>(operand), - cast(operand0.getType()))); - } - Value output = rewriter.create( - loc, tensor::getMixedSizes(rewriter, loc, operand0), - resultType.getElementType()); - - auto linalgOp = rewriter.create( - loc, coercedOperands, output, - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. We scalarize the operands and add a - // scalar operand representing the output tensor. - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getComputation(), region, region.end()); - TypeConverter::SignatureConversion signatureConverter(op.getNumOperands()); - - for (auto [idx, operand] : llvm::enumerate(op.getOperands())) { - Type convertedTy = getTypeConverter()->convertType( - cast(operand.getType()).getElementType()); - if (!convertedTy) - return rewriter.notifyMatchFailure(op, - "operand type conversion failed"); - signatureConverter.addInputs(idx, convertedTy); - } - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - auto result = rewriter.createOrFold(loc, resultType, - linalgOp.getResults()); - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// This lowering encompasses the full range of the Gather operation and -/// therefore is very general and just loops over the output and calculate the -/// corresponding input index. It follows the explanation at -/// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler -/// should be able to optimize that a bit, but in order to get efficient -/// lowerings, special-cases of gather should be extracted in separate -/// lowerings, and ideally encapsulated as separate ops or canonicalization -/// patterns. -struct GatherConversion final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::GatherOp gatherOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = gatherOp.getLoc(); - - Value startIndices = adaptor.getStartIndices(); - Value operand = adaptor.getOperand(); - - auto resultType = - getTypeConverter()->convertType(gatherOp.getType()); - RankedTensorType startIndicesType = - dyn_cast(startIndices.getType()); - // We could actually deal with an unranked result by inferring the result - // rank, but the current reifyReturnTypes doesn't support unranked either. - if (!resultType || !startIndicesType) { - return rewriter.notifyMatchFailure(gatherOp, - "unranked start indices or result"); - } - - int64_t resultRank = resultType.getRank(); - // slice_sizes has to have the same size as operand.rank, and doing it this - // way permits an unranked operand. - int64_t operandRank = gatherOp.getSliceSizes().size(); - - int64_t indexVectorDim = gatherOp.getDimensionNumbers().getIndexVectorDim(); - - ArrayRef offsetDims = - gatherOp.getDimensionNumbers().getOffsetDims(); - ArrayRef collapsedSliceDims = - gatherOp.getDimensionNumbers().getCollapsedSliceDims(); - ArrayRef startIndexMap = - gatherOp.getDimensionNumbers().getStartIndexMap(); - - // We'll need these later and creating them on demand we end up with - // duplicates, which also makes lit tests really hard to write. - SmallVector constants; - for (int64_t i = 0, e = std::max({resultRank, operandRank, int64_t{2}}); - i < e; ++i) { - constants.push_back( - rewriter.create(loc, rewriter.getIndexAttr(i))); - } - - Value emptyOp = getEmptyTensorFor(rewriter, loc, resultType, gatherOp, - adaptor.getOperands()); - - ValueRange ins; - SmallVector indexingMaps( - {rewriter.getMultiDimIdentityMap(resultRank)}); - auto linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/resultType, - /*inputs=*/ins, - /*outputs=*/emptyOp, indexingMaps, getNParallelLoopsAttrs(resultRank), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(gatherOp)); - - // Now populate the linalg generic region - Region ®ion = linalgOp.getRegion(); - Block *block = rewriter.createBlock(®ion, region.end()); - block->addArguments(resultType.getElementType(), loc); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(block); - - // Dimensions in the result that aren't offset dimensions are called batch. - SmallVector batchDims; - for (int64_t dim = 0; dim < resultRank; ++dim) { - if (!llvm::is_contained(offsetDims, dim)) { - batchDims.push_back(dim); - } - } - - // Same as with the constants. Creating these all up front is easier than - // potentially getting duplicates later. - SmallVector linalgIndices; - for (int64_t i = 0; i < resultRank; ++i) { - linalgIndices.push_back(rewriter.create(loc, i)); - } - - // Now the complicated part. For a given output dimension we build up an - // index into the input. It's composed of two parts: the index coming from - // start_indices, and the offset from that index along the offset - // dimensions. Everything includes dimension shuffling and remapping as well - // because of the way gather is defined to allow for any-layout input by - // adding more attributes. - - // The base gather index (`G` in the documentation) points to a place in - // start_indices along the batch dimensions. - SmallVector gatherIndex; - for (int64_t dim : batchDims) { - gatherIndex.push_back(linalgIndices[dim]); - } - - SmallVector indexFromStartIndices; - for (size_t i = 0, e = startIndexMap.size(); i != e; ++i) { - // The index along the index_vector dimension of start_indices varies. - // Basically indexFromStartIndices indexes into a "row" along - // index_vector_dim, where the row is selected by the current output - // index. - // But if index_vector_dim is equal to start_indices.rank, then - // start_indices gets a trailing 1 dimension added. So the row we're - // extracting always has length 1 and the index into it is always 0, so we - // just use the gather index directly - SmallVector gCombine(gatherIndex); - if (indexVectorDim != startIndicesType.getRank()) { - assert(indexVectorDim <= static_cast(gCombine.size())); - gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]); - } - - indexFromStartIndices.push_back(extractIndexFromTensor( - rewriter, loc, startIndices, gatherOp.getStartIndices().getType(), - gCombine)); - } - - // But then start indices are shuffled by the start index map. To make a - // full index into the operand, all missing indices are zeroes. - SmallVector remappedIndexFromIndices(operandRank, constants[0]); - for (auto [idx, value] : llvm::enumerate(startIndexMap)) { - remappedIndexFromIndices[value] = indexFromStartIndices[idx]; - } - - // Now we construct the index based on the offset. First we need to remap - // the offset dimensions by dropping the collapsed indices. - SmallVector remappedOffsetDims; - for (int64_t i = 0; i < operandRank; ++i) { - if (!llvm::is_contained(collapsedSliceDims, i)) { - remappedOffsetDims.push_back(static_cast(i)); - } - } - - assert(remappedOffsetDims.size() == offsetDims.size()); - - // Clamp out of bounds indices. - for (int i = 0, operandIndexDim = 0; i < operandRank; ++i) { - // Compute the size of the output shape dimension corresponding to this - // index dimension. If it's collapsed set it to 1. - Value outputDimSize = constants[1]; - if (!llvm::is_contained(collapsedSliceDims, i)) { - outputDimSize = rewriter.createOrFold( - loc, emptyOp, offsetDims[operandIndexDim++]); - } - - // If this is a skipped dimension, we're done and don't have to clamp. - if (remappedIndexFromIndices[i] == constants[0]) - continue; - - Value operandDimSize = - rewriter.createOrFold(loc, operand, i); - Value largestValidIndex = rewriter.createOrFold( - loc, operandDimSize, outputDimSize); - - // Clamp indices to [0, i, operand_dim-output_dim]. - Value clamp = rewriter.create( - loc, - rewriter.create(loc, constants[0], - remappedIndexFromIndices[i]), - largestValidIndex); - remappedIndexFromIndices[i] = clamp; - } - - // For the (remapped) offset dimensions, the index is the current index in - // the output. As before this is expanded to a full index into the operand - // by using zeros for the missing indices. - SmallVector indexFromOffset(operandRank, constants[0]); - for (auto [remappedOffsetDim, offsetDim] : - llvm::zip_equal(remappedOffsetDims, offsetDims)) { - indexFromOffset[remappedOffsetDim] = linalgIndices[offsetDim]; - } - - // Now we add together our two indices to get the final index into the - // operand. - SmallVector combinedIndex; - for (int64_t i = 0; i < operandRank; ++i) - combinedIndex.push_back(rewriter.createOrFold( - loc, rewriter.getIndexType(), remappedIndexFromIndices[i], - indexFromOffset[i])); - - Value extractOperand; - if (isa(operand.getType())) { - extractOperand = operand; - } else { - // Cannot extract from unranked tensors, cast to ranked first. - SmallVector dims(operandRank, ShapedType::kDynamic); - auto type = RankedTensorType::get( - dims, cast(operand.getType()).getElementType()); - extractOperand = rewriter.create(loc, type, operand); - } - Value element = - rewriter.create(loc, extractOperand, combinedIndex); - rewriter.create(loc, element); - - rewriter.replaceOp(gatherOp, linalgOp.getResults()); - - return success(); - } -}; - -/// Converts xla-hlo.select_and_scatter op to a sequence of linalg.generics ops. -/// The current version computes the scattered index and populates the correct -/// value for each tile. It does not currently handle overlapping tiles. -struct SelectAndScatterNoOverlapConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::SelectAndScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - ImplicitLocOpBuilder b(loc, rewriter); - Value source = op.getSource(); - Value operand = op.getOperand(); - Value init = op.getInitValue(); - - auto sourceTy = llvm::dyn_cast(source.getType()); - auto operandTy = llvm::dyn_cast(operand.getType()); - auto initTy = llvm::dyn_cast(init.getType()); - auto resultTy = llvm::dyn_cast(op.getResult().getType()); - if (!sourceTy || !operandTy || !initTy || !resultTy) - return rewriter.notifyMatchFailure(op, "inputs/outputs must be ranked"); - - auto indexETy = b.getI32Type(); - auto srcETy = operandTy.getElementType(); - auto destETy = initTy.getElementType(); - - const int64_t rank = sourceTy.getRank(); - - llvm::SmallVector pad(rank * 2, 0); - if (op.getPadding().has_value()) - pad = llvm::to_vector(op.getPaddingAttr().getValues()); - - // TODO(suderman): Add support for padding. - if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) - return rewriter.notifyMatchFailure(op, "non-zero padding values found."); - - if (!op.getWindowStrides().has_value()) - return rewriter.notifyMatchFailure(op, "no window strides found"); - - if (!op.getWindowDimensions().has_value()) - return rewriter.notifyMatchFailure(op, "no window dimensions found"); - - auto strides = llvm::to_vector(op.getWindowStrides().value()); - auto window = llvm::to_vector(op.getWindowDimensions().value()); - - if (static_cast(strides.size()) != operandTy.getRank() || - static_cast(window.size()) != operandTy.getRank()) - return rewriter.notifyMatchFailure( - op, "stride/window length should equal operand rank"); - - // The current version cannot handle overlapped regions. - for (int i = 0, s = strides.size(); i < s; ++i) { - if (strides[i] < window[i]) - return rewriter.notifyMatchFailure( - op, "overlapping windows are not supported"); - } - - // If the window only contains a single element, this lowering will be - // problematic. Ultimately we should handle this with a canonicalizer. - if (llvm::all_of(window, [](auto sz) { return sz == 1; })) { - return rewriter.notifyMatchFailure(op, - "unary window size is not supported"); - } - - // The first linalg.generic operation computes the relevant index over - // window for the defined stablehlo.select_and_scatter. This involves - // iterating over the window of the operand a computing the index. - // Rather than storing N indices we compute the row major identifier - // in the window, to specify which location should be scattered to. - - // Output specifies the `rank` parallel iterators that correspond to - // output values. - SmallVector outputExprs; - for (int i = 0, s = rank; i < s; ++i) - outputExprs.push_back(b.getAffineDimExpr(i)); - - // For the output we need to define the reduction across the window - // width and height. This includes applying striding behavior and - // adding the additional reduction iterators. We skip length-1 dimensions - // as the reduction is degenerate. - SmallVector filteredWindows, filteredStrides; - SmallVector sourceExprs(outputExprs); - SmallVector windowExprs; - for (int i = 0, s = rank; i < s; ++i) { - sourceExprs[i] = sourceExprs[i] * strides[i]; - if (strides[i] != 1) { - auto expr = b.getAffineDimExpr(windowExprs.size() + sourceExprs.size()); - sourceExprs[i] = sourceExprs[i] + expr; - windowExprs.push_back(expr); - filteredWindows.push_back(window[i]); - filteredStrides.push_back(strides[i]); - } - } - - // Determine the total number of AffineExprs and construct the IndexingMaps - // for the windowed reduction operation. - const int64_t reduceExprCount = windowExprs.size() + sourceExprs.size(); - SmallVector reduceIndexingMaps; - reduceIndexingMaps.push_back(AffineMap::get(reduceExprCount, - /*symbolCount=*/0, sourceExprs, - rewriter.getContext())); - reduceIndexingMaps.push_back(AffineMap::get(reduceExprCount, - /*symbolCount=*/0, windowExprs, - rewriter.getContext())); - auto reduceOutMap = - AffineMap::get(reduceExprCount, - /*symbolCount=*/0, outputExprs, rewriter.getContext()); - reduceIndexingMaps.push_back(reduceOutMap); - reduceIndexingMaps.push_back(reduceOutMap); - - // Output sizes should match the dimensions of the `source` tensor, even if - // dynamic. - SmallVector reduceDynSizes; - for (int i = 0, s = rank; i < s; ++i) - if (sourceTy.isDynamicDim(i)) - reduceDynSizes.push_back(b.create(source, i)); - - Value reduceValueEmpty = - b.create(sourceTy.getShape(), destETy, reduceDynSizes); - Value reduceIndexEmpty = b.create( - sourceTy.getShape(), indexETy, reduceDynSizes); - - // We initialize indices to -1 which indicates no matching destination. - Value negativeOne = b.create(b.getI32IntegerAttr(-1)); - reduceIndexEmpty = - b.create(negativeOne, reduceIndexEmpty).getResult(0); - - // We only care to match the reduction dimensions. - Value windowEmpty = b.create(filteredWindows, srcETy); - - auto reduceGeneric = b.create( - /*resultTensors=*/ArrayRef{reduceValueEmpty.getType(), - reduceIndexEmpty.getType()}, - /*inputs=*/ValueRange{operand, windowEmpty}, - /*outputs=*/ValueRange{reduceValueEmpty, reduceIndexEmpty}, - reduceIndexingMaps, - getParallelAndReductionIterators(reduceExprCount, windowExprs.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // First we clone in the selection block. - auto &reduceRegion = reduceGeneric.getRegion(); - rewriter.setInsertionPoint(reduceGeneric); - rewriter.cloneRegionBefore(op.getSelect(), reduceRegion, - reduceRegion.end()); - - // This includes convert `stablehlo` scalar-tensor regions to `linalg` - // scalars. - TypeConverter::SignatureConversion reduceSignConverter(4); - reduceSignConverter.addInputs(0, srcETy); - reduceSignConverter.addInputs(srcETy); - reduceSignConverter.addInputs(1, destETy); - reduceSignConverter.addInputs(indexETy); - rewriter.applySignatureConversion(&reduceRegion.front(), - reduceSignConverter, getTypeConverter()); - - // Grab the terminator and use the turned value to now select the - // correct index and value. - auto &reduceBlock = reduceRegion.front(); - auto *reduceTerminator = reduceBlock.getTerminator(); - Value selectPred = reduceTerminator->getOperand(0); - Value selectInVal = reduceBlock.getArgument(0); - Value selectOutVal = reduceBlock.getArgument(2); - Value selectOutIdx = reduceBlock.getArgument(3); - - b.setInsertionPoint(reduceTerminator); - - // The predicate operates on scalar-tensors, so we need to extract the - // value for `linalg` operations. Tensor-ops are cleaned up by other - // rewriters. - selectPred = b.create(rewriter.getI1Type(), selectPred, - ValueRange{}); - - // We select if either the selection function returns `true` or the - // current reduction index is `-1`, e.g. no index has been selected yet. - Value selectNegOne = b.create(arith::CmpIPredicate::eq, - selectOutIdx, negativeOne); - selectPred = b.create(selectPred, selectNegOne); - - // We compute a unique idx for each element in the window. - Value computedIdx = b.create(rank); - for (int i = 1, s = filteredStrides.size(); i < s; ++i) { - Value width = b.create(filteredStrides[i]); - Value idx = b.create(rank + i); - computedIdx = b.create(width, computedIdx); - computedIdx = b.create(computedIdx, idx); - } - computedIdx = b.create(indexETy, computedIdx); - - // Using the selection predicate track the value and selected - // identifier for the future scattering. - Value selectedIdx = - b.create(selectPred, computedIdx, selectOutIdx); - Value selectedValue = - b.create(selectPred, selectInVal, selectOutVal); - b.create(ValueRange{selectedValue, selectedIdx}); - - // Original terminator is an stablehlo.return we no longer need. - rewriter.eraseOp(reduceTerminator); - b.setInsertionPoint(op); - - Value reduceIndex = reduceGeneric.getResult(1); - ShapedType reduceIndexTy = llvm::cast(reduceIndex.getType()); - - // For the second generic we restricted to only cases where there are - // no window overlaps. This guarantees that each source value is scattered - // within its own unique window. We can broadcast to this window size and - // populate only the relative location. - llvm::SmallVector broadcastShape; - llvm::SmallVector broadcastDynDims; - llvm::SmallVector broadcastExprs; - for (int i = 0, s = reduceIndexTy.getRank(); i < s; ++i) { - int64_t broadcast = strides[i]; - if (sourceTy.isDynamicDim(i)) - broadcastDynDims.push_back(b.create(source, i)); - - broadcastExprs.push_back(b.getAffineDimExpr(broadcastShape.size())); - broadcastShape.push_back(sourceTy.getDimSize(i)); - if (broadcast > 1) { - broadcastShape.push_back(broadcast); - } - } - - // We broadcast the values of our input tensors across the stride-tiling - // size. - Value scatterEmpty = b.create( - broadcastShape, resultTy.getElementType(), broadcastDynDims); - Value initScalar = b.create(initTy.getElementType(), - init, ValueRange{}); - Value scatterFill = - b.create(initScalar, scatterEmpty).getResult(0); - - // Both the indices and values are broadcasted using the same indexing map. - // Output fully parallel. - auto scatterInputMap = - AffineMap::get(broadcastShape.size(), /*symbolCount=*/0, broadcastExprs, - b.getContext()); - SmallVector scatterIndexingMaps = { - scatterInputMap, scatterInputMap, - b.getMultiDimIdentityMap(broadcastShape.size())}; - - auto scatterGeneric = b.create( - /*resultTensors=*/ArrayRef{scatterFill.getType()}, - /*inputs=*/ValueRange{reduceIndex, source}, - /*outputs=*/ValueRange{scatterFill}, scatterIndexingMaps, - getNParallelLoopsAttrs(broadcastShape.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Clone the scattering combination logic and perform the tensor-to-scalar - // conversion. - auto &scatterRegion = scatterGeneric.getRegion(); - b.setInsertionPoint(scatterGeneric); - rewriter.cloneRegionBefore(op.getScatter(), scatterRegion, - scatterRegion.end()); - - TypeConverter::SignatureConversion scatterSignConverter(4); - scatterSignConverter.addInputs(indexETy); - scatterSignConverter.addInputs(0, sourceTy.getElementType()); - scatterSignConverter.addInputs(1, sourceTy.getElementType()); - rewriter.applySignatureConversion(&scatterRegion.front(), - scatterSignConverter, getTypeConverter()); - - auto &scatterBlock = scatterRegion.front(); - auto scatterTerminator = scatterBlock.getTerminator(); - b.setInsertionPoint(scatterTerminator); - - Value scatterInputIdx = scatterBlock.getArgument(0); - Value scatterOutputVal = scatterBlock.getArgument(2); - Value scatterUpdate = b.create( - sourceTy.getElementType(), scatterTerminator->getOperand(0), - ValueRange{}); - - // Compute the index of the tiled region to determine if it was selected. - Value id = b.create(0); - int64_t dim = 0; - for (int i = 0, s = strides.size(); i < s; ++i) { - if (strides[i] > 1) { - Value idx = b.create(++dim); - Value tileSz = b.create(strides[i]); - id = b.create(id, tileSz); - id = b.create(id, idx); - } - ++dim; - } - - // Check whether the computed id matches the to-scatter id, then select and - // yield. - id = b.create(indexETy, id); - auto scatterPred = b.create( - b.getI1Type(), arith::CmpIPredicate::eq, id, scatterInputIdx); - scatterUpdate = - b.create(scatterPred, scatterUpdate, scatterOutputVal); - - b.create(scatterUpdate); - rewriter.eraseOp(scatterTerminator); - b.setInsertionPoint(op); - - // We now need to collapse the tiles back into their - // source dimensions. We collapse any of the broadcast regions together. - int64_t collapseDim = 0; - SmallVector reassociationMap; - for (int i = 0, s = window.size(); i < s; ++i) { - SmallVector dims = {collapseDim}; - if (strides[i] > 1) - dims.push_back(collapseDim + 1); - - reassociationMap.push_back(ReassociationIndices(dims)); - collapseDim += dims.size(); - } - - Value collapse = b.create( - scatterGeneric.getResult(0), reassociationMap); - auto collapseTy = llvm::cast(collapse.getType()); - - // After collapsing it it possible that the target may need to be padded. - auto zero = b.createOrFold(0); - SmallVector padShape; - SmallVector padLow, padHigh; - padLow.resize(operandTy.getRank(), zero); - - for (int i = 0, s = rank; i < s; ++i) { - int64_t size = std::max(resultTy.getDimSize(i), collapseTy.getDimSize(i)); - if (operandTy.isDynamicDim(i) || collapseTy.isDynamicDim(i)) - size = ShapedType::kDynamic; - padShape.push_back(size); - - Value in = b.create(collapse, i); - Value out = b.create(operand, i); - Value diff = b.create(out, in); - Value pad = b.createOrFold(diff, zero); - padHigh.push_back(pad); - } - - Value padded = b.create(collapseTy.clone(padShape), collapse, - padLow, padHigh, initScalar); - - // The result may exceed the target size, slice if necessary. - SmallVector sliceSizes; - SmallVector sliceOffsets(operandTy.getRank(), - b.getIndexAttr(0)); - SmallVector sliceStrides(operandTy.getRank(), - b.getIndexAttr(1)); - for (int i = 0, s = operandTy.getRank(); i < s; ++i) { - OpFoldResult dim = b.getIndexAttr(operandTy.getDimSize(i)); - if (operandTy.isDynamicDim(i)) - dim = b.createOrFold(operand, i); - sliceSizes.push_back(dim); - } - - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, padded, sliceOffsets, sliceSizes, sliceStrides); - - return success(); - } -}; - -// Decomposes a pad with negative edge padding into a pad without negative edge -// padding and a tensor.extract_slice. -struct PadOpNegativePaddingConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::PadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector padLow; - SmallVector padHigh; - SmallVector sliceStarts; - - bool hasNegativePadding = false; - for (int64_t low : op.getEdgePaddingLow()) { - if (low >= 0) { - padLow.push_back(low); - sliceStarts.push_back(rewriter.getIndexAttr(0)); - } else { - padLow.push_back(0); - sliceStarts.push_back(rewriter.getIndexAttr(-low)); - hasNegativePadding = true; - } - } - - for (int64_t high : op.getEdgePaddingHigh()) { - if (high >= 0) { - padHigh.push_back(high); - } else { - padHigh.push_back(-high); - hasNegativePadding = true; - } - } - - // If there's no negative edge padding we're done. - if (!hasNegativePadding) - return failure(); - - // Create a new pad op with the positive values. - Value pad = rewriter.create( - op.getLoc(), adaptor.getOperand(), adaptor.getPaddingValue(), - rewriter.getDenseI64ArrayAttr(padLow), - rewriter.getDenseI64ArrayAttr(padHigh), op.getInteriorPadding()); - - // Then slice according to the negative edge padding. Static shapes only for - // now. - if (!op.getType().hasStaticShape()) - return failure(); - SmallVector sizes( - llvm::map_range(op.getType().getShape(), [&](int64_t dim) { - return rewriter.getIndexAttr(dim); - })); - SmallVector strides(sliceStarts.size(), - rewriter.getIndexAttr(1)); - rewriter.replaceOpWithNewOp(op, pad, sliceStarts, - sizes, strides); - return success(); - } -}; - -/// Converts stablehlo.pad operation to tensor.pad or tensor.insert_slice. -struct PadOpConversion final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::PadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto resultType = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - // Negative edge padding is decomposed separately. - auto isNegative = [](int64_t intVal) { return intVal < 0; }; - if (llvm::any_of(op.getEdgePaddingLow(), isNegative) || - llvm::any_of(op.getEdgePaddingHigh(), isNegative)) - return failure(); - - Value paddingVal = rewriter.createOrFold( - loc, adaptor.getPaddingValue()); - - auto i64ToFoldResult = [&](const int64_t &i) -> OpFoldResult { - return rewriter.getIntegerAttr(rewriter.getI64Type(), i); - }; - - // If there is no interior padding lower to tensor.pad directly. - if (llvm::all_of(op.getInteriorPadding(), - [](const int64_t &i) { return i == 0; })) { - auto padTensorOp = rewriter.create( - loc, resultType, adaptor.getOperand(), - llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), - llvm::map_to_vector(op.getEdgePaddingHigh(), i64ToFoldResult), - paddingVal); - rewriter.replaceOp(op, padTensorOp.getResult()); - return success(); - } - - // We have interior padding, which can be lowered to tensor.insert_slice. - // Start by filling a result-sized tensor with the pad value. - auto emptyTensor = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - auto fill = - rewriter.create(loc, paddingVal, emptyTensor).result(); - - // Get sizes of the original operand. - auto operandType = llvm::cast(adaptor.getOperand().getType()); - auto sizes = llvm::map_to_vector( - llvm::seq(0, operandType.getRank()), - [&](int64_t dim) -> OpFoldResult { - if (!operandType.isDynamicDim(dim)) - return rewriter.getIndexAttr(operandType.getDimSize(dim)); - return rewriter.create(loc, adaptor.getOperand(), dim) - .getResult(); - }); - // Map interior padding to strides. - auto strides = llvm::map_to_vector( - op.getInteriorPadding(), [&](const int64_t &stride) -> OpFoldResult { - return rewriter.getIntegerAttr(rewriter.getI64Type(), stride + 1); - }); - - rewriter.replaceOpWithNewOp( - op, adaptor.getOperand(), fill, - llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), sizes, - strides); - return success(); - } -}; - -/// Converts xla-hlo.torch_index_select op to a linalg.generic op. -struct TorchIndexSelectOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::TorchIndexSelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - int axis = static_cast(op.getDim()); - int batch = static_cast(op.getBatchDims()); - auto indexShapedType = llvm::cast(adaptor.getIndex().getType()); - int numIndices = static_cast(indexShapedType.getRank()); - auto operandShapedType = - llvm::cast(adaptor.getOperand().getType()); - if (axis < 0) - axis += static_cast(operandShapedType.getRank()); - if (batch < 0) - batch += numIndices; - - Location loc = op.getLoc(); - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - int rank = static_cast(resultType.getRank()); - - // The output shape is - // `params[:axis] + indices[batch_dims:] + params[axis + 1:]` - SmallVector dynSizes; - for (int i = 0; i < rank; ++i) { - if (!resultType.isDynamicDim(i)) - continue; - if (i < axis) { - dynSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), i)); - } else if (i < (axis + numIndices - batch)) { - int idx = i - axis + batch; - dynSizes.push_back( - rewriter.create(loc, adaptor.getIndex(), idx)); - } else { - int idx = i - (axis + numIndices - batch) + axis + 1; - dynSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), idx)); - } - } - - // Generate dummy tensor to preserve slice shape information. - SmallVector sliceShape; - SmallVector dynSliceSizes; - SmallVector sliceExprs; - ArrayRef resultShape = resultType.getShape(); - for (int i = 0; i < axis; ++i) { - sliceExprs.push_back(rewriter.getAffineDimExpr(i)); - sliceShape.push_back(resultShape[i]); - if (!resultType.isDynamicDim(i)) - continue; - dynSliceSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), i)); - } - for (int i = axis + numIndices - batch; i < rank; ++i) { - sliceExprs.push_back(rewriter.getAffineDimExpr(i)); - sliceShape.push_back(resultShape[i]); - if (!resultType.isDynamicDim(i)) - continue; - int idx = i - (axis + numIndices - batch) + axis + 1; - dynSliceSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), idx)); - } - - // Setup AffineMap for operand tensor. - SmallVector exprs; - for (int i = 0; i < batch; ++i) { - exprs.push_back(rewriter.getAffineDimExpr(i)); - } - for (int i = 0, e = numIndices - batch; i < e; ++i) { - exprs.push_back(rewriter.getAffineDimExpr(axis + i)); - } - - SmallVector indexingMaps; - indexingMaps.emplace_back( - AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext())); - indexingMaps.emplace_back(AffineMap::get( - rank, /*symbolCount=*/0, sliceExprs, rewriter.getContext())); - indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank)); - - Value sliceOp = rewriter.create( - loc, sliceShape, resultType.getElementType(), dynSliceSizes); - - Value emptyOp = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), dynSizes); - auto linalgOp = rewriter.create( - loc, /*resultTensors=*/ArrayRef{resultType}, - /*inputs=*/ValueRange{adaptor.getIndex(), sliceOp}, - /*outputs=*/emptyOp, indexingMaps, getNParallelLoopsAttrs(rank), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - SmallVector bodyArgTypes; - SmallVector linalgOpArgs = {adaptor.getIndex(), sliceOp}; - // Add a block to the region. - auto *region = &linalgOp.getRegion(); - auto *block = rewriter.createBlock(region, region->end()); - for (auto blockArgs : linalgOpArgs) { - bodyArgTypes.push_back( - llvm::cast(blockArgs.getType()).getElementType()); - } - block->addArguments(bodyArgTypes, - SmallVector(bodyArgTypes.size(), loc)); - block->addArguments(resultType.getElementType(), loc); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(block); - - Value castedValue = rewriter.create( - loc, rewriter.getIndexType(), block->getArgument(0)); - - SmallVector indices; - for (int i = 0; i < axis; ++i) { - indices.push_back(rewriter.create(loc, i)); - } - indices.push_back(castedValue); - for (int i = axis + numIndices - batch; i < rank; ++i) { - indices.push_back(rewriter.create(loc, i)); - } - Value res = - rewriter.create(loc, adaptor.getOperand(), indices); - rewriter.create(loc, res); - - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct SetDimensionSizeConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::SetDimensionSizeOp setDimensionSizeOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // We can lower SetDimensionSize to tensor extract. This turns into a - // regular dynamic shape. Note that the bounds annotation is still around - // but may be no longer valid depending on choices made by bufferization. - Location loc = setDimensionSizeOp.getLoc(); - auto resultType = dyn_cast(setDimensionSizeOp.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(setDimensionSizeOp, - "expected a ranked tensor"); - - SmallVector offsets(resultType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector strides(resultType.getRank(), - rewriter.getIndexAttr(1)); - SmallVector sizes(llvm::map_range( - resultType.getShape(), [&](int64_t dim) -> OpFoldResult { - return rewriter.getIndexAttr(dim); - })); - Value dimensionSize = - rewriter.create(loc, setDimensionSizeOp.getSize()); - sizes[setDimensionSizeOp.getDimension()] = - rewriter - .create(loc, rewriter.getIndexType(), - dimensionSize) - .getResult(); - - rewriter.replaceOpWithNewOp( - setDimensionSizeOp, resultType, adaptor.getOperand(), offsets, sizes, - strides); - return success(); - } -}; - -struct ConvertStableHloToLinalg final - : impl::ConvertStableHloToLinalgBase { - using ConvertStableHloToLinalgBase::ConvertStableHloToLinalgBase; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext &ctx = getContext(); - RewritePatternSet patterns(&ctx); - ConversionTarget target(ctx); - target.addLegalDialect< - bufferization::BufferizationDialect, arith::ArithDialect, - complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, - tensor::TensorDialect, sparse_tensor::SparseTensorDialect, - scf::SCFDialect, shape::ShapeDialect>(); - - target.addLegalOp(); - - auto typeConverter = createStableHloToLinalgTypeConverter(); - ModuleOp module = getOperation(); - - populateStableHloToLinalgConversionPatterns(&ctx, *typeConverter, &patterns, - this->enablePrimitiveOps); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -void populateStableHloToLinalgConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet *patterns, - bool enablePrimitiveOps) { - // clang-format off - patterns->add< - BitcastConvertConverter, - ConcatenateConverter, - ConstConverterTensor, - EinsumToLinalgConverter, - GatherConversion, - RealDynamicSliceConverter, - ReshapeOpConverter, - ReverseConverter, - SetDimensionSizeConverter, - SliceConverter, - DynamicSliceConverter, - DynamicUpdateSliceConverter, - PadOpConversion, - PadOpNegativePaddingConversion, - TorchIndexSelectOpConversion, - SelectAndScatterNoOverlapConverter - >(typeConverter, context); - - detail::populatePointwiseStableHloToLinalgConversionPatterns( - context, typeConverter, patterns, enablePrimitiveOps); - - if (enablePrimitiveOps) { - patterns->add< - BroadcastInDimOpToBroadcastConverter, - BroadcastOpToBroadcastConverter, - DynamicBroadcastInDimOpToBroadcastConverter, - IotaToMapConverter, - IotaToMapConverter, - MapOpToMapConverter, - TransposeOpToTransposeConverter - >(typeConverter, context); - } else { - patterns->add< - BroadcastConverter, - IotaConverter, - IotaConverter, - HloBroadcastInDimConverter, - HloDynamicBroadcastInDimConverter, - MapOpToGenericConverter, - TransposeConverter - >(typeConverter, context); - } - - // clang-format on - - detail::populateStableHloConvolutionToLinalgConversionPatterns( - context, typeConverter, patterns); - detail::populateStableHloDotProdToLinalgConversionPatterns( - context, typeConverter, patterns); - detail::populateStableHloRandomToLinalgConversionPatterns( - context, typeConverter, patterns); - detail::populateStableHloReductionToLinalgConversionPatterns( - context, typeConverter, patterns, enablePrimitiveOps); - detail::populateScalarHloToArithConversionPatterns( - context, typeConverter, patterns, isInBodyOfLinalgOps); - linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns); -} - -std::unique_ptr createStableHloToLinalgTypeConverter() { - return std::make_unique(); -} - -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp deleted file mode 100644 index f51fcd66399a..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp +++ /dev/null @@ -1,807 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO convolution ops to Linalg dialect. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -/// Apply dilation and padding to the input of a convolution. -Value applyConvolutionPadding(Location loc, Value input, - DenseIntElementsAttr padding, - std::optional> lhsDilation, - llvm::ArrayRef dimMappings, - OpBuilder &rewriter) { - SmallVector lhsDilationValues; - if (lhsDilation.has_value()) - lhsDilationValues = llvm::to_vector(lhsDilation.value()); - bool noPadding = !padding || isSplatValue(padding, 0); - bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1); - if (noPadding && noDilation) - return input; - - auto inputType = cast(input.getType()); - int64_t rank = inputType.getRank(); - - // Translate window padding into low/high padding. - SmallVector padLow(rank, 0); - SmallVector padHigh(rank, 0); - if (padding) { - // The padding attribute contains two values per dimension, but excludes the - // batch and feature dimensions. - assert(rank * 2 == padding.size() + 4 && - "There should be 2 padding values per dimension, i.e low and high."); - for (int64_t i : llvm::seq(0, padding.size() / 2)) { - int64_t dim = dimMappings[i]; - padLow[dim] = padding.getValues()[i * 2]; - padHigh[dim] = padding.getValues()[i * 2 + 1]; - } - } - - // Translate input dilation into interior padding. - SmallVector padInterior(rank, 0); - if (lhsDilation) { - assert(rank == static_cast(lhsDilationValues.size()) + 2); - for (int64_t i : llvm::seq(0, lhsDilationValues.size())) { - int64_t dim = dimMappings[i]; - padInterior[dim] = lhsDilationValues[i] - 1; - } - } - - Value zero; - if (auto complexType = dyn_cast(inputType.getElementType())) { - auto zeroElement = rewriter.getZeroAttr(complexType.getElementType()); - auto zeroAttr = rewriter.getArrayAttr({zeroElement, zeroElement}); - zero = rewriter.create(loc, complexType, zeroAttr); - zero = rewriter.create( - loc, RankedTensorType::get({}, complexType), zero); - } else { - zero = rewriter.create( - loc, rewriter.getZeroAttr( - RankedTensorType::get({}, inputType.getElementType()))); - } - - return rewriter.create(loc, input, zero, padLow, - padHigh, padInterior); -} - -/// If the ConvolutionOp has a window reversal, applies it to the filter. -Value applyConvolutionReversal(Location loc, OpBuilder &b, - mlir::stablehlo::ConvolutionOp op, - Value filter) { - std::optional reversals = op.getWindowReversal(); - if (!reversals.has_value()) { - return filter; - } - llvm::SmallVector reversedDims; - for (auto [idx, reversed] : llvm::enumerate(reversals.value())) { - if (reversed) { - reversedDims.push_back( - op.getDimensionNumbers().getKernelSpatialDimensions()[idx]); - } - } - - return b.create( - loc, filter, b.getDenseI64ArrayAttr(reversedDims)); -} - -/// Returns true if the given `dimensionNumbers` from a stablehlo.convolution op -/// follows a canonical form: -/// -/// * Input dimensions have order: (batch_count, spatial_dims, -/// input_channel_count). -/// * Filter dimensions have order: (spatial_dims, input_channel_count, -/// output_channel_count). -/// * Output dimensions have order: (batch_count, spatial_dims, -/// output_channel_count). -bool hasCanonicalDimensionNumbers( - mlir::stablehlo::ConvDimensionNumbersAttr dimensionNumbers) { - const int64_t inputSpatialRank = - dimensionNumbers.getInputSpatialDimensions().size(); - // The dimensions for input should follow the order of - // batch_count, spatial_dims..., input_feature_count. - if (dimensionNumbers.getInputBatchDimension() != 0 || - dimensionNumbers.getInputFeatureDimension() != (inputSpatialRank + 1)) { - return false; - } - - const int64_t kernelSpatialRank = - dimensionNumbers.getKernelSpatialDimensions().size(); - // The dimensions for filter should follow the order of - // spatial_dims..., input_feature_count, num_output_feature_count. - if (dimensionNumbers.getKernelInputFeatureDimension() != kernelSpatialRank || - dimensionNumbers.getKernelOutputFeatureDimension() != - (kernelSpatialRank + 1)) { - return false; - } - - const int64_t outputSpatialRank = - dimensionNumbers.getOutputSpatialDimensions().size(); - // The dimensions for output should follow the order of - // batch_count, spatial_dims.., output_feature_count. - if (dimensionNumbers.getOutputBatchDimension() != 0 || - dimensionNumbers.getOutputFeatureDimension() != (outputSpatialRank + 1)) { - return false; - } - - if (inputSpatialRank != outputSpatialRank || - inputSpatialRank != kernelSpatialRank) { - return false; - } - - const int64_t *inputSpatialDim = - dimensionNumbers.getInputSpatialDimensions().data(); - const int64_t *kernelSpatialDim = - dimensionNumbers.getKernelSpatialDimensions().data(); - const int64_t *outputSpatialDim = - dimensionNumbers.getOutputSpatialDimensions().data(); - // Check spatial dims are ordered correctly. - for (int64_t i = 0; i < inputSpatialRank; ++i) { - const int64_t dim = i + 1; - if ((*inputSpatialDim++) != dim || (*outputSpatialDim++) != dim || - (*kernelSpatialDim++) != i) { - return false; - } - } - - return true; -} - -/// Converts stablehlo.conv operation to linalg named op. This only covers -/// normal convolution cases. The op must have canonical dimension numbers. -/// Depthwise convolution and pointwise convolution are not handled in the -/// conversion. -struct NormalConvolutionOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!hasCanonicalDimensionNumbers(op.getDimensionNumbers())) { - return failure(); - } - if (op.getFeatureGroupCount() != 1u) - return failure(); - if (op.getBatchGroupCount() != 1u) - return failure(); - - Location loc = op.getLoc(); - Value input = adaptor.getLhs(); - Value filter = adaptor.getRhs(); - filter = applyConvolutionReversal(loc, rewriter, op, filter); - auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - - int64_t rank = resultType.getRank(); - - // Immediately emit an EmptyOp for output tensors with zero dimension. - if (llvm::is_contained(resultType.getShape(), 0)) { - rewriter.replaceOpWithNewOp(op, resultType.getShape(), - resultType.getElementType()); - return success(); - } - - // The output shape is N spatial_dims F. - SmallVector dynSizes; - if (resultType.isDynamicDim(0)) { - dynSizes.push_back(rewriter.create(loc, input, 0)); - } - for (int64_t i = 1, e = rank - 1; i < e; ++i) { - if (resultType.isDynamicDim(i)) { - return rewriter.notifyMatchFailure( - op, "expected output spatial dims to be static shapes"); - } - } - if (resultType.isDynamicDim(rank - 1)) { - dynSizes.push_back(rewriter.create(loc, filter, rank - 1)); - } - Value emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), dynSizes); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - linalg::LinalgOp res; - Attribute strides; - if (auto s = op.getWindowStrides()) - strides = rewriter.getI64TensorAttr(*s); - Attribute dilations; - if (auto d = op.getRhsDilation()) - dilations = rewriter.getI64TensorAttr(*d); - - // Apply padding and input dilation. - llvm::SmallVector spatialDimMapping(rank - 2); - std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1); - input = applyConvolutionPadding(loc, input, op.getPaddingAttr(), - op.getLhsDilation(), spatialDimMapping, - rewriter); - - switch (rank) { - case 2: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - linalg::getPrunedAttributeList(op)); - break; - } - case 3: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - strides, dilations, linalg::getPrunedAttributeList(op)); - break; - } - case 4: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - strides, dilations, linalg::getPrunedAttributeList(op)); - break; - } - case 5: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - strides, dilations, linalg::getPrunedAttributeList(op)); - break; - } - default: { - return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op"); - } - } - rewriter.replaceOp(op, res.getOperation()->getResults()); - return success(); - } -}; - -/// Handles all possible inputs for the mlir::stablehlo::ConvolutionOp -struct ConvolutionOpGeneralConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - /// This lowering proceeds with the following steps: - /// 1. Handle padding and dilation of the input - /// 2. Handle padding and dilation of the window - /// 3. Handle reversal of the window - /// 4. If feature_group_count != 1: - /// - Reshape the input feature dimension, kernel output feature dimension, - /// and output feature dimension. - /// - Create the AffineExpr for the new dimension - /// - Conceptually, this splits the input feature and both output feature - /// dimensions and computes sets of convolutions with these partial views - /// of the values as if they were multiple convolutions combined in a - /// batch. - /// 5: If batch_group_count != 1: - /// - Reshape the input batch dimension, kernel output feature dimension, - /// and output feature dimension. - /// - Create the AffineExpr for the new dimension - /// - Conceptually, this splits the input batch and both output feature - /// dimensions and computes sets of convolutions with these partial views - /// of the values as if they were multiple convolutions combined in a - /// batch. - /// 6. For all dimensions not newly created by a reshape, create the - /// appropriate parallel and reduction dimensions to create a convolution. - /// 7. Create the linalg.generic that computes the multiply-add - /// 8. Reshape the output to the original shape if it was reshaped by the - /// feature or group count attributes. - LogicalResult - matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = op.getContext(); - - auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - - auto reshapedResultShape = resultType.getShape().vec(); - if (!resultType.hasStaticShape()) - return failure(); - - // Immediately emit an EmptyOp for output tensors with zero dimension. - if (llvm::is_contained(reshapedResultShape, 0)) { - rewriter.replaceOpWithNewOp(op, reshapedResultShape, - resultType.getElementType()); - return success(); - } - - mlir::stablehlo::ConvDimensionNumbersAttr dimensionNumbers = - op.getDimensionNumbers(); - int64_t inputBatchDimension = dimensionNumbers.getInputBatchDimension(); - int64_t inputFeatureDimension = dimensionNumbers.getInputFeatureDimension(); - ArrayRef inputSpatialDimensions = - dimensionNumbers.getInputSpatialDimensions(); - - int64_t kernelInputFeatureDimension = - dimensionNumbers.getKernelInputFeatureDimension(); - int64_t kernelOutputFeatureDimension = - dimensionNumbers.getKernelOutputFeatureDimension(); - ArrayRef kernelSpatialDimensions = - dimensionNumbers.getKernelSpatialDimensions(); - - int64_t outputBatchDimension = dimensionNumbers.getOutputBatchDimension(); - int64_t outputFeatureDimension = - dimensionNumbers.getOutputFeatureDimension(); - ArrayRef outputSpatialDimensions = - dimensionNumbers.getOutputSpatialDimensions(); - - size_t featureGroupCount = op.getFeatureGroupCount(); - size_t batchGroupCount = op.getBatchGroupCount(); - - if (op.getFeatureGroupCount() != 1 && op.getBatchGroupCount() != 1) { - return rewriter.notifyMatchFailure( - op, "only one of feature and batch group counts can be non-one"); - } - - // Decompose the convolution into an initial padding - Value modifiedLhs = applyConvolutionPadding( - op.getLoc(), adaptor.getLhs(), adaptor.getPaddingAttr(), - adaptor.getLhsDilation(), - op.getDimensionNumbers().getInputSpatialDimensions(), rewriter); - Value modifiedRhs = applyConvolutionPadding( - op.getLoc(), adaptor.getRhs(), nullptr, adaptor.getRhsDilation(), - op.getDimensionNumbers().getKernelSpatialDimensions(), rewriter); - modifiedRhs = applyConvolutionReversal(loc, rewriter, op, modifiedRhs); - - // Non-one values for feature or batch group counts will result in reshaped - // inputs and outputs. These mappings are used to keep track of the the new - // index after reshaping has possibly inserted new dimensions. - auto paddedLhsType = cast(modifiedLhs.getType()); - auto paddedRhsType = cast(modifiedRhs.getType()); - SmallVector lhsIndexMapping(paddedLhsType.getRank()); - std::iota(lhsIndexMapping.begin(), lhsIndexMapping.end(), 0); - SmallVector rhsIndexMapping(paddedRhsType.getRank()); - std::iota(rhsIndexMapping.begin(), rhsIndexMapping.end(), 0); - SmallVector resultIndexMapping(resultType.getRank()); - std::iota(resultIndexMapping.begin(), resultIndexMapping.end(), 0); - auto updateDimMappingFromOffset = - [](llvm::SmallVectorImpl &mapping, int64_t offset) { - for (auto &mappingElt : llvm::drop_begin(mapping, offset)) { - mappingElt += 1; - } - }; - - // The rest of this code prepares the inputs and a single linalg::GenericOp - // to execute the convolution. The final linalg::GenericOp will be iterated - // through based on the following eventual maps. - SmallVector srcExprs(paddedLhsType.getRank()); - SmallVector windowExprs(paddedRhsType.getRank()); - SmallVector dstExprs(reshapedResultShape.size()); - int64_t nextDim = 0; - int64_t rank = resultType.getRank(); - - auto reshapeShapeVector = [](llvm::ArrayRef oldShape, - llvm::SmallVectorImpl &newShape, - int64_t reshapedDim, int64_t factor) { - newShape.reserve(oldShape.size() + 1); - for (int64_t i : llvm::seq(0, oldShape.size())) { - if (i == reshapedDim) { - newShape.push_back(factor); - newShape.push_back(oldShape[reshapedDim] / factor); - } else { - newShape.push_back(oldShape[i]); - } - } - }; - - // If batch or feature count groupings exist, represent this through - // reshaping the input to have an additional dimension that these groupings - // exist along, and reduce in that dimension - SmallVector iterationLoops; - if (featureGroupCount != 1) { - AffineExpr parallelDim = mlir::getAffineDimExpr(nextDim++, ctx); - iterationLoops.push_back(utils::IteratorType::parallel); - // Reshape LHS - { - srcExprs.insert(srcExprs.begin() + inputFeatureDimension, parallelDim); - auto prevDimsRef = paddedLhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, inputFeatureDimension, - featureGroupCount); - updateDimMappingFromOffset(lhsIndexMapping, inputFeatureDimension); - modifiedLhs = rewriter.create( - loc, - RankedTensorType::get(newShape, paddedLhsType.getElementType()), - modifiedLhs); - } - - // Reshape RHS - { - windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension, - parallelDim); - auto prevDimsRef = paddedRhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension, - featureGroupCount); - updateDimMappingFromOffset(rhsIndexMapping, - kernelOutputFeatureDimension); - modifiedRhs = rewriter.create( - loc, - RankedTensorType::get(newShape, paddedRhsType.getElementType()), - modifiedRhs); - } - // Prepare reshaped output shape - { - dstExprs.insert(dstExprs.begin() + outputFeatureDimension, parallelDim); - updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension); - reshapedResultShape.insert(reshapedResultShape.begin() + - outputFeatureDimension, - featureGroupCount); - reshapedResultShape[outputFeatureDimension + 1] /= featureGroupCount; - } - } - - if (batchGroupCount != 1) { - iterationLoops.push_back(utils::IteratorType::parallel); - AffineExpr parallelDim = mlir::getAffineDimExpr(nextDim++, ctx); - // Reshape LHS - { - srcExprs.insert(srcExprs.begin() + inputBatchDimension, parallelDim); - ArrayRef prevDimsRef = paddedLhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, inputBatchDimension, - batchGroupCount); - updateDimMappingFromOffset(lhsIndexMapping, inputBatchDimension); - modifiedLhs = rewriter.create( - op.getLoc(), - RankedTensorType::get(newShape, paddedLhsType.getElementType()), - modifiedLhs); - } - - // Reshape RHS - { - windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension, - parallelDim); - ArrayRef prevDimsRef = paddedRhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension, - batchGroupCount); - updateDimMappingFromOffset(rhsIndexMapping, - kernelOutputFeatureDimension); - modifiedRhs = rewriter.create( - op.getLoc(), - RankedTensorType::get(newShape, paddedRhsType.getElementType()), - modifiedRhs); - } - // Prepare reshaped output shape - { - int64_t outputFeatureDim = resultIndexMapping[outputFeatureDimension]; - dstExprs.insert(dstExprs.begin() + outputFeatureDim, parallelDim); - updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension); - reshapedResultShape.insert( - reshapedResultShape.begin() + outputFeatureDim, batchGroupCount); - reshapedResultShape[outputFeatureDim + 1] /= batchGroupCount; - } - } - - // Handle input feature dimension - { - iterationLoops.push_back(utils::IteratorType::reduction); - AffineExpr inputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx); - srcExprs[lhsIndexMapping[inputFeatureDimension]] = inputFeatureDim; - windowExprs[rhsIndexMapping[kernelInputFeatureDimension]] = - inputFeatureDim; - } - - // Handle output feature dimension - { - iterationLoops.push_back(utils::IteratorType::parallel); - AffineExpr outputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx); - dstExprs[resultIndexMapping[outputFeatureDimension]] = outputFeatureDim; - windowExprs[rhsIndexMapping[kernelOutputFeatureDimension]] = - outputFeatureDim; - } - - // Handle spatial Dimensions - int64_t numSpatialDims = rank - 2; - for (int64_t i = 0; i < numSpatialDims; ++i) { - iterationLoops.push_back(utils::IteratorType::parallel); - iterationLoops.push_back(utils::IteratorType::reduction); - AffineExpr dim0 = mlir::getAffineDimExpr(nextDim++, ctx); - AffineExpr dim1 = mlir::getAffineDimExpr(nextDim++, ctx); - - AffineExpr stride = dim0; - if (op.getWindowStrides().has_value()) - stride = stride * op.getWindowStrides().value()[i]; - AffineExpr srcExpr = stride + dim1; - - srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr; - dstExprs[resultIndexMapping[outputSpatialDimensions[i]]] = dim0; - windowExprs[rhsIndexMapping[kernelSpatialDimensions[i]]] = dim1; - } - - // Handle batch dimension - { - iterationLoops.push_back(utils::IteratorType::parallel); - AffineExpr batchDim = mlir::getAffineDimExpr(nextDim++, ctx); - - srcExprs[lhsIndexMapping[inputBatchDimension]] = batchDim; - dstExprs[resultIndexMapping[outputBatchDimension]] = batchDim; - } - - // Finally, create the computation - auto inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); - - Value emptyTensor = rewriter.create( - loc, reshapedResultShape, resultType.getElementType()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - - Value convolved = - rewriter - .create( - loc, - /*resultTensors=*/ - llvm::ArrayRef(zeroTensor.getType()), - /*inputs=*/ - llvm::ArrayRef({modifiedLhs, modifiedRhs}), - /*outputs=*/llvm::ArrayRef(zeroTensor), inferredMaps, - iterationLoops, - /*bodyBuild=*/ - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange) { - ImplicitLocOpBuilder builder(nestedLoc, nestedBuilder); - linalg::Conv2DOp::regionBuilder( - builder, *builder.getInsertionBlock(), {}); - }, - linalg::getPrunedAttributeList(op)) - .getResult(0); - rewriter.replaceOpWithNewOp(op, resultType, - convolved); - - return success(); - } -}; - -/// Converts stablehlo.convolution operation to -/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or -/// depthwise_conv_2d_input_nhwc_filter_hwc op. -struct DepthwiseConvolutionOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getBatchGroupCount() != 1) - return failure(); - // Fall into the normal convolution cases. - if (op.getFeatureGroupCount() == 1) - return failure(); - - const mlir::stablehlo::ConvDimensionNumbersAttr &dimensionNumbers = - op.getDimensionNumbers(); - const int64_t spatialRank = - dimensionNumbers.getInputSpatialDimensions().size(); - if (spatialRank == 0 || spatialRank > 3) { - return rewriter.notifyMatchFailure(op, "only support up to 3D for now"); - } - - // Make sure that this is depthwise convolution. - int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension(); - int64_t inputFeatureCount = - cast(op.getLhs().getType()).getDimSize(inputFeatureDim); - if (static_cast(op.getFeatureGroupCount()) != inputFeatureCount) { - return rewriter.notifyMatchFailure(op, "not depth-wise convolution"); - } - - // Make sure that this convolution has a canonical form. - if (!hasCanonicalDimensionNumbers(dimensionNumbers)) { - return rewriter.notifyMatchFailure(op, "does not have canonical form"); - } - - Attribute windowStrides; - if (op.getWindowStrides()) { - windowStrides = rewriter.getI64TensorAttr(op.getWindowStrides().value()); - } else { - windowStrides = SplatElementsAttr::get( - VectorType::get({spatialRank}, rewriter.getI64Type()), - rewriter.getI64IntegerAttr(1)); - } - - Attribute rhsDilation; - if (op.getRhsDilation()) { - rhsDilation = rewriter.getI64TensorAttr(op.getRhsDilation().value()); - } else { - rhsDilation = SplatElementsAttr::get( - VectorType::get({spatialRank}, rewriter.getI64Type()), - rewriter.getI64IntegerAttr(1)); - } - - Location loc = op.getLoc(); - Value input = adaptor.getLhs(); - Value filter = adaptor.getRhs(); - auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - if (!resultType.hasStaticShape()) { - return rewriter.notifyMatchFailure(op, - "expected output has static shapes"); - } - - // Immediately emit an EmptyOp for output tensors with zero dimension. - if (llvm::is_contained(resultType.getShape(), 0)) { - rewriter.replaceOpWithNewOp(op, resultType.getShape(), - resultType.getElementType()); - return success(); - } - - // Apply padding and input dilation. - llvm::SmallVector spatialDimMapping(spatialRank); - std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1); - input = applyConvolutionPadding(loc, input, op.getPaddingAttr(), - op.getLhsDilation(), spatialDimMapping, - rewriter); - - auto filterDims = - llvm::to_vector(cast(op.getRhs().getType()).getShape()); - - auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) { - SmallVector reassociations; - int64_t rank = cast(v.getType()).getRank(); - for (int64_t i = 0; i < rank - 1; ++i) - reassociations.emplace_back(1, i); - reassociations.back().push_back(rank - 1); - return reassociations; - }; - - int64_t kernelInputFeatureDimension = - dimensionNumbers.getKernelInputFeatureDimension(); - int64_t kernelOutputFeatureDimension = - dimensionNumbers.getKernelOutputFeatureDimension(); - if (filterDims[kernelInputFeatureDimension] * - filterDims[kernelOutputFeatureDimension] != - static_cast(op.getFeatureGroupCount())) { - // For cases where channel multiplier != 1 - - // Reshaping filter shape - // [filter_height, filter_width, 1, kernel-output-feature]. - // to - // [filter_height, filter_width, feature_group_count, - // kernel-output-feature/feature_group_count ] - SmallVector reshapedFilterDims; - reshapedFilterDims.assign(filterDims.begin(), filterDims.end()); - Value reshapedFilter = filter; - if (filterDims[kernelInputFeatureDimension] == 1) { - reshapedFilterDims[kernelInputFeatureDimension] = - op.getFeatureGroupCount(); - reshapedFilterDims[kernelOutputFeatureDimension] /= - op.getFeatureGroupCount(); - auto reshapedFilterType = RankedTensorType::get( - reshapedFilterDims, - cast(op.getRhs().getType()).getElementType()); - - reshapedFilter = rewriter.create( - loc, reshapedFilterType, filter); - } - - ArrayRef outputDims = resultType.getShape(); - int64_t channelMultiplier = reshapedFilterDims.back(); - SmallVector reshapedOutputDims; - reshapedOutputDims.assign(outputDims.begin(), outputDims.end()); - reshapedOutputDims.push_back(channelMultiplier); - reshapedOutputDims[reshapedOutputDims.size() - 2] /= channelMultiplier; - - Value emptyTensor = rewriter.create( - loc, reshapedOutputDims, resultType.getElementType()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - - auto reshapedOutputType = RankedTensorType::get( - reshapedOutputDims, resultType.getElementType()); - Value conv; - switch (spatialRank) { - case 1: { - conv = - rewriter - .create( - loc, reshapedOutputType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)) - .getResult(0); - break; - } - case 2: { - conv = - rewriter - .create( - loc, reshapedOutputType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)) - .getResult(0); - break; - } - case 3: { - conv = - rewriter - .create( - loc, reshapedOutputType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)) - .getResult(0); - break; - } - default: - llvm_unreachable("Unhandled case"); - } - - // Create a Linalg reshape op that converts the output from 5 dimensions - // into 4 dimensions (by collapsing the last two dimensions). This is - // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns - // 5 dimensions for the output. - rewriter.replaceOpWithNewOp( - op, resultType, conv, - getReassociationIndicesToCollapseLastTwoDims(conv)); - } else { - // For cases where channel multiplier == 1 - Value emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - - // Create a Linalg reshape op that converts the filter from 4 dimensions - // into 3 dimensions (by droping the unit dimension). This is needed - // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3 - // dimensions for the filter. - - filterDims[filterDims.size() - 2] = - static_cast(op.getFeatureGroupCount()); - filterDims.pop_back(); - - RankedTensorType filterShape = - RankedTensorType::get(filterDims, op.getType().getElementType()); - - Value reshapedFilter = rewriter.create( - loc, filterShape, filter, - getReassociationIndicesToCollapseLastTwoDims(filter)); - - switch (spatialRank) { - case 1: - rewriter.replaceOpWithNewOp( - op, resultType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)); - break; - case 2: - rewriter.replaceOpWithNewOp( - op, resultType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)); - break; - case 3: - rewriter.replaceOpWithNewOp( - op, resultType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)); - break; - } - } - - return success(); - } -}; - -} // namespace - -namespace detail { -void populateStableHloConvolutionToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - // Ensure specialized patterns are higher priority than their generic - // versions. - patterns - ->add( - typeConverter, context, PatternBenefit(2)); - - patterns->add(typeConverter, context); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp deleted file mode 100644 index 0db4fafec6bc..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO dot product ops to Linalg dialect. -// These patterns are separated out to their own file to save on the compilation -// times, given that we instantiate a large number of class templates here. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -enum class DotOperationType { - kVectorDot = 0, - kMatrixVector, - kVectorMatrix, - kMatrixMatrix, - kUnsupported -}; - -DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) { - ArrayRef lhsShape = - cast(dotOp.getLhs().getType()).getShape(); - ArrayRef rhsShape = - cast(dotOp.getRhs().getType()).getShape(); - auto shapeMatches = [](int64_t a, int64_t b) { - return ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b; - }; - if (lhsShape.size() == 1 && rhsShape.size() == 1 && - shapeMatches(lhsShape[0], rhsShape[0])) { - return DotOperationType::kVectorDot; - } - if (lhsShape.size() == 2 && rhsShape.size() == 1 && - shapeMatches(lhsShape[1], rhsShape[0])) { - return DotOperationType::kMatrixVector; - } - if (lhsShape.size() == 1 && rhsShape.size() == 2 && - shapeMatches(lhsShape[0], rhsShape[0])) { - return DotOperationType::kVectorMatrix; - } - if (lhsShape.size() == 2 && rhsShape.size() == 2 && - shapeMatches(lhsShape[1], rhsShape[0])) { - return DotOperationType::kMatrixMatrix; - } - return DotOperationType::kUnsupported; -} - -SmallVector getDotOpEmptyTensorDynSizes(OpBuilder &b, Location loc, - Value lhs, Value rhs, - DotOperationType type) { - SmallVector dynShape; - switch (type) { - case DotOperationType::kMatrixMatrix: { - if (llvm::cast(lhs.getType()).isDynamicDim(0)) - dynShape.push_back(b.create(loc, lhs, 0)); - if (llvm::cast(rhs.getType()).isDynamicDim(1)) - dynShape.push_back(b.create(loc, rhs, 1)); - break; - } - case DotOperationType::kMatrixVector: { - if (llvm::cast(lhs.getType()).isDynamicDim(0)) - dynShape.push_back(b.create(loc, lhs, 0)); - break; - } - case DotOperationType::kVectorMatrix: { - if (llvm::cast(rhs.getType()).isDynamicDim(1)) - dynShape.push_back(b.create(loc, rhs, 1)); - break; - } - case DotOperationType::kVectorDot: - case DotOperationType::kUnsupported: - break; - } - return dynShape; -} - -template -struct DotOpConversion final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = mlir::stablehlo::DotOp::Adaptor; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) { - return failure(); - } - if (getDotOperationType(op) != op_type) - return failure(); - - Location loc = op.getLoc(); - // Convert unsigned to signed. This works because signed and unsigned - // integer matmul is the same operation in two's complement. - auto outputType = - cast(getTypeConverter()->convertType(op.getType())); - SmallVector dynShape = getDotOpEmptyTensorDynSizes( - rewriter, loc, adaptor.getLhs(), adaptor.getRhs(), op_type); - Value emptyTensor = - !sparse_tensor::getSparseTensorEncoding(outputType) - ? getEmptyTensor(rewriter, loc, outputType, dynShape) - : getEmptySparseTensor(rewriter, loc, outputType, dynShape); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - rewriter.replaceOpWithNewOp( - op, TypeRange{outputType}, - ValueRange{adaptor.getLhs(), adaptor.getRhs()}, ValueRange{zeroTensor}, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct DotGeneralBatchMatMulOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) { - return failure(); - } - if (llvm::cast(op.getType()).getRank() != 3) { - return rewriter.notifyMatchFailure(op, "expected a batch matmul"); - } - - mlir::stablehlo::DotDimensionNumbersAttr dimNumbers = - op.getDotDimensionNumbers(); - ArrayRef lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); - ArrayRef rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); - ArrayRef lhsContractingDims = - dimNumbers.getLhsContractingDimensions(); - ArrayRef rhsContractingDims = - dimNumbers.getRhsContractingDimensions(); - if (lhsBatchingDims.size() != 1 || lhsBatchingDims[0] != 0) { - return rewriter.notifyMatchFailure( - op, "expected lhs batching dimensions exactly {0}"); - } - if (rhsBatchingDims.size() != 1 || rhsBatchingDims[0] != 0) { - return rewriter.notifyMatchFailure( - op, "expected rhs batching dimensions exactly {0}"); - } - if (lhsContractingDims.size() != 1 || lhsContractingDims[0] != 2) { - return rewriter.notifyMatchFailure( - op, "expected lhs contracting dimensions exactly {2}"); - } - if (rhsContractingDims.size() != 1 || rhsContractingDims[0] != 1) { - return rewriter.notifyMatchFailure( - op, "expected rhs contracting dimensions exactly {1}"); - } - - Location loc = op.getLoc(); - // Convert unsigned to signed. This works because signed and unsigned - // integer matmul is the same operation in two's complement. - auto outputType = - cast(typeConverter->convertType(op.getType())); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - Operation *linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/TypeRange{outputType}, - /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()}, - /*outputBuffers=*/ValueRange{zeroTensor}, - linalg::getPrunedAttributeList(op)); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; - -struct DotGeneralOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) { - return failure(); - } - - // Get various dimension iterator information - mlir::stablehlo::DotDimensionNumbersAttr dimNumbers = - op.getDotDimensionNumbers(); - ArrayRef lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); - ArrayRef rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); - ArrayRef lhsContractingDims = - dimNumbers.getLhsContractingDimensions(); - ArrayRef rhsContractingDims = - dimNumbers.getRhsContractingDimensions(); - - // Get shape information and initialize output - assert(lhsContractingDims.size() == rhsContractingDims.size() && - "number of contracting dims must be equal"); - size_t numContracting = lhsContractingDims.size(); - // Convert unsigned to signed. This works because signed and unsigned - // integer matmul is the same operation in two's complement. - auto outputType = - cast(typeConverter->convertType(op.getType())); - size_t targetRank = outputType.getRank(); - size_t totalLoopCount = numContracting + targetRank; - - int64_t lhsRank = - llvm::cast(adaptor.getLhs().getType()).getRank(); - size_t lhsExtraDims = - lhsRank - lhsBatchingDims.size() - lhsContractingDims.size(); - int64_t rhsRank = - llvm::cast(adaptor.getRhs().getType()).getRank(); - - Location loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - SmallVector indexingMaps; - - auto getMap = [&](int64_t rank, ArrayRef batchingDims, - ArrayRef contractingDims, size_t extraDims) { - llvm::SmallVector indices(rank); - for (const auto &i : llvm::enumerate(batchingDims)) { - indices[i.value()] = rewriter.getAffineDimExpr(i.index()); - } - for (const auto &i : llvm::enumerate(contractingDims)) { - indices[i.value()] = rewriter.getAffineDimExpr(i.index() + targetRank); - } - for (int i = 0; i < rank; ++i) { - if (!indices[i]) { - indices[i] = rewriter.getAffineDimExpr(extraDims++); - } - } - indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount, - /*symbolCount=*/0, indices, - op->getContext())); - }; - getMap(lhsRank, lhsBatchingDims, lhsContractingDims, - lhsBatchingDims.size()); - getMap(rhsRank, rhsBatchingDims, rhsContractingDims, - rhsBatchingDims.size() + lhsExtraDims); - - { - SmallVector dimExprs; - dimExprs.reserve(targetRank); - for (unsigned i = 0; i < targetRank; ++i) - dimExprs.push_back(rewriter.getAffineDimExpr(i)); - indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount, - /*symbolCount=*/0, dimExprs, - op.getContext())); - } - - Operation *linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/TypeRange{outputType}, - /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()}, - /*outputBuffers=*/ValueRange{zeroTensor}, indexingMaps, - getParallelAndReductionIterators( - /*nLoops=*/totalLoopCount, - /*nReduction=*/numContracting), - [](OpBuilder &b, Location loc, ValueRange) { - ImplicitLocOpBuilder builder(loc, b); - linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {}); - }, - linalg::getPrunedAttributeList(op)); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; - -} // namespace - -namespace detail { -void populateStableHloDotProdToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - // Ensure specialized patterns are higher priority than their generic - // versions. - patterns - ->add, - DotOpConversion, - DotOpConversion, - DotOpConversion, - DotGeneralBatchMatMulOpConversion>(typeConverter, context, - PatternBenefit(2)); - patterns->add(typeConverter, context, - PatternBenefit(1)); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp deleted file mode 100644 index 6f2985abd50b..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO pointwise ops to Linalg dialect. -// These patterns are separated out to their own file to save on the compilation -// times, given that we instantiate a large number of class templates here. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -int64_t getRank(Value v) { return cast(v.getType()).getRank(); } - -int64_t getMaxRank(ValueRange operands) { - int64_t maxRank = 0; - for (Value operand : operands) { - maxRank = std::max(maxRank, getRank(operand)); - } - return maxRank; -} - -bool isScalar(Value v) { return getRank(v) == 0; } - -/// Inserts block arguments in places where scalar inputs have a nullptr. -SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, - ValueRange blockArgs) { - SmallVector result; - auto argsIter = blockArgs.begin(); - for (Value scalarInput : scalarInputs) { - if (scalarInput) { - result.push_back(scalarInput); - } else { - result.push_back(*argsIter); - ++argsIter; - } - } - return result; -} - -struct PointwiseConversionInfo { - int64_t maxOperandRank = 0; - ShapedType resultType; -}; - -/// Checks the preconditions for conversion of pointwise HLO ops to linalg. -/// Returns the max operand rank and the result type on success. -FailureOr -checkOperandsAndResults(Operation *op, ValueRange operands, - const TypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { - int64_t maxRank = getMaxRank(operands); - - // Apply only if all operands are scalar or have the same rank. Some ops, - // like `stablehlo.select`, support implicit broadcasting of scalars. - if (!llvm::all_of(operands, [&](Value v) { - int64_t r = getRank(v); - return r == 0 || r == maxRank; - })) { - return rewriter.notifyMatchFailure( - op, "Operands must be of same rank or scalar."); - } - - // Find result type, if on tensors. - auto resultTy = dyn_cast_or_null( - typeConverter.convertType(op->getResultTypes().front())); - - // Check result type compatibility. - if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || - !(resultTy.getElementType().isSignlessIntOrFloat() || - isa(resultTy.getElementType()))) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); - } - - // All-scalar pointwise ops inside of linalg ops are processes by - // ScalarHloToArithmeticPattern. - if (maxRank == 0 && isInBodyOfLinalgOps(op)) - return failure(); - - return PointwiseConversionInfo{maxRank, resultTy}; -} - -/// Converts a HLO operation to a linalg.map op that contains the corresponding -/// scalar operations. -template -struct PointwiseToLinalgMapConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto conversionInfo = checkOperandsAndResults( - op, adaptor.getOperands(), *this->typeConverter, rewriter); - if (failed(conversionInfo)) { - return failure(); - } - - int64_t maxRank = conversionInfo->maxOperandRank; - ShapedType resultTy = conversionInfo->resultType; - Location loc = op.getLoc(); - - // Find input/output values and types. - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - // Mapped inputs are cast to the same shape as the init tensor. - // Values from scalar inputs are extracted and used directly in the block. - SmallVector mappedInputs; - SmallVector scalarInputs; - for (Value input : adaptor.getOperands()) { - if (getRank(input) == maxRank) { - mappedInputs.push_back(coerceTensorShape( - rewriter, loc, cast>(input), - cast(emptyTensor.getType()))); - scalarInputs.push_back(nullptr); - } else { - scalarInputs.push_back(rewriter.create(loc, input)); - } - } - - auto mapOp = rewriter.create( - loc, mappedInputs, emptyTensor, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, getElementTypeOrSelf(emptyTensor), - interleaveScalarAndBlockArgs(scalarInputs, args), &b); - - b.create(loc, innerResult); - }, - linalg::getPrunedAttributeList(op)); - - rewriter.replaceOp(op, mapOp->getResults()); - return success(); - } -}; - -/// Converts a HLO operation to a linalg.generic op that contains the -/// corresponding scalar operations. -template -struct PointwiseToLinalgConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto conversionInfo = checkOperandsAndResults( - op, adaptor.getOperands(), *this->typeConverter, rewriter); - if (failed(conversionInfo)) { - return failure(); - } - - int64_t maxRank = conversionInfo->maxOperandRank; - ShapedType resultTy = conversionInfo->resultType; - Location loc = op.getLoc(); - - // Find input/output values and types. - ValueRange inputs = adaptor.getOperands(); - Value output = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - // Create indexing maps. - AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext()); - AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank); - SmallVector maps; - for (Value v : inputs) - maps.push_back(isScalar(v) ? scalarMap : idMap); - maps.push_back(idMap); - - // Build `linalg.generic` op. - bool failed = false; - auto linalgOp = rewriter.create( - loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps, - getNParallelLoopsAttrs(maxRank), - [&](OpBuilder &nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - Type innerResultTy = getElementTypeOrSelf(output); - auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); - Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter); - Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, innerResultTy, argvec, &rewriter); - if (!innerResult) { - failed = true; - } else { - innerResult = postSparsify(op, semiring, innerResult, &rewriter); - nestedBuilder.create(loc, innerResult); - } - }, - linalg::getPrunedAttributeList(op)); - if (failed) - return failure(); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; -} // namespace - -namespace detail { -void populatePointwiseStableHloToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, bool enablePrimitiveOps) { - if (enablePrimitiveOps) { - patterns->add< - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter>(typeConverter, - context); - return; - } - - patterns - ->add, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter>(typeConverter, - context); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp deleted file mode 100644 index e1e73ed75bbc..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp +++ /dev/null @@ -1,917 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO random number generation to Linalg -// dialect. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -class ArithOpBuilder { -public: - ArithOpBuilder(OpBuilder b, Location l, Value v) - : builder(b), loc(l), value(v) {} - - explicit operator Value() { return value; } - Value val() { return value; } - - ArithOpBuilder constantI(int64_t value, int64_t bits) { - Value val = builder.create( - loc, builder.getIntegerAttr(builder.getIntegerType(bits), value)); - return ArithOpBuilder(builder, loc, val); - } - - ArithOpBuilder extendUI(int32_t bits) { - Value ext = builder.create( - loc, builder.getIntegerType(bits), value); - return ArithOpBuilder(builder, loc, ext); - } - - ArithOpBuilder truncI(int64_t bits) { - if (value.getType().getIntOrFloatBitWidth() == bits) - return *this; - Value trunc = builder.create( - loc, builder.getIntegerType(bits), value); - return ArithOpBuilder(builder, loc, trunc); - } - - ArithOpBuilder linalgIndex(int32_t index) { - Value val = builder.create(loc, index); - return ArithOpBuilder(builder, loc, val); - } - - ArithOpBuilder indexCast(int32_t bitwidth) { - if (isa(value.getType())) { - Value cast = builder.create( - loc, builder.getIndexType(), value); - return ArithOpBuilder(builder, loc, cast); - } - - Value cast = builder.create( - loc, builder.getIntegerType(bitwidth), value); - return ArithOpBuilder(builder, loc, cast); - } - - ArithOpBuilder rotateLeft(int32_t rotation) { - int32_t bits = value.getType().getIntOrFloatBitWidth(); - ArithOpBuilder cLeft = constantI(rotation, bits); - ArithOpBuilder cRight = constantI(bits - rotation, bits); - ArithOpBuilder rLeft = (*this << cLeft); - ArithOpBuilder rRight = (*this >> cRight); - return rLeft | rRight; - } - - ArithOpBuilder operator+(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator*(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator|(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator^(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator<<(ArithOpBuilder &rhs) { - Value shl = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, shl); - } - - ArithOpBuilder operator>>(ArithOpBuilder &rhs) { - Value shr = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, shr); - } - -private: - OpBuilder builder; - Location loc; - Value value; -}; - -std::pair splitI64(ArithOpBuilder i64) { - auto low = i64.truncI(32); - auto c32 = i64.constantI(/*value=*/32, /*bits=*/64); - auto high = (i64 >> c32).truncI(32); - return {low, high}; -} - -ArithOpBuilder fuseI32s(ArithOpBuilder low, ArithOpBuilder high) { - auto c32 = high.constantI(/*value=*/32, /*bits=*/64); - high = high.extendUI(64) << c32; - low = low.extendUI(64); - return low | high; -} - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -std::pair -runThreeFry2xi32(ArithOpBuilder key0, ArithOpBuilder key1, - ArithOpBuilder initialState) { - ArithOpBuilder index = initialState.linalgIndex(0); - index = index.indexCast(64); - index = index + initialState; - - // Split into the 2xi32 used for threefry. - std::pair input = splitI64(index); - ArithOpBuilder input0 = input.first; - ArithOpBuilder input1 = input.second; - - // Magic number and rotation distances specified by the Threefry2x32 - // algorithm. - llvm::SmallVector rotations = {13, 15, 26, 6, 17, 29, 16, 24}; - ArithOpBuilder magic = key0.constantI(/*value=*/0x1bd11bda, /*bits=*/32); - - ArithOpBuilder key2 = magic ^ key0 ^ key1; - std::array ks{key0, key1, key2}; - std::array x{input0 + key0, input1 + key1}; - - // Performs a single round of the Threefry2x32 algorithm, with a rotation - // amount 'rotation'. - for (int i = 0; i < 5; ++i) { - int32_t rot = (4 * i) % rotations.size(); - int32_t k1 = (i + 1) % ks.size(); - int32_t k2 = (i + 2) % ks.size(); - - for (int j = 0; j < 4; ++j) { - x[0] = x[0] + x[1]; - x[1] = x[1].rotateLeft(rotations[rot + j]); - x[1] = x[0] ^ x[1]; - } - - ArithOpBuilder c = x[0].constantI(/*value=*/i + 1, /*bits=*/32); - x[0] = x[0] + ks[k1]; - x[1] = x[1] + ks[k2]; - x[1] = x[1] + c; - } - - return std::pair(x[0], x[1]); -} - -// Extract and potentially reconstruct the i32 key-pair as necessary. -std::pair extractKey32(OpBuilder &builder, Location loc, - Value store) { - auto storeTy = cast(store.getType()); - if (storeTy.getRank() != 1) - return {nullptr, nullptr}; - - Type storeETy = storeTy.getElementType(); - IntegerType i32Ty = builder.getIntegerType(32); - IntegerType i64Ty = builder.getIntegerType(64); - - if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { - Value idx0 = builder.create(loc, 0); - Value idx1 = builder.create(loc, 1); - Value key0 = builder.create(loc, store, idx0); - Value key1 = builder.create(loc, store, idx1); - key0 = builder.create(loc, i32Ty, key0); - key1 = builder.create(loc, i32Ty, key1); - return {key0, key1}; - } - - if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 0); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); - return std::pair(pair.first, pair.second); - } - - // TODO(#14859): Properly handle 128-bit storage keys. - if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 0); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); - return std::pair(pair.first, pair.second); - } - - return {nullptr, nullptr}; -} - -// Extract and potentially reconstruct the i64 state as necessary. -Value extractState64(OpBuilder &builder, Location loc, Value store) { - auto storeTy = cast(store.getType()); - if (storeTy.getRank() != 1) - return nullptr; - - Type storeETy = storeTy.getElementType(); - IntegerType i64Ty = builder.getIntegerType(64); - - if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 1); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - return cast; - } - - // TODO(#14859): Properly handle 128-bit storage keys. - if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 1); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - return cast; - } - - if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { - Value idx2 = builder.create(loc, 2); - Value idx3 = builder.create(loc, 3); - - Value low = builder.create(loc, store, idx2); - Value high = builder.create(loc, store, idx3); - - ArithOpBuilder i64 = fuseI32s(ArithOpBuilder(builder, loc, high), - ArithOpBuilder(builder, loc, low)); - return builder.create(loc, i64Ty, i64.val()); - } - - return nullptr; -} - -Value setState64(OpBuilder &b, Location loc, Value store, Value state) { - auto storeTy = cast(store.getType()); - if (storeTy.getRank() != 1) - return nullptr; - - Type storeETy = storeTy.getElementType(); - - if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { - state = b.create(loc, storeETy, state); - Value idx1 = b.create(loc, 1); - return b.create(loc, storeTy, state, store, - ValueRange{idx1}); - } - - // TODO(#14859): Properly handle 128-bit storage keys. - if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { - state = b.create(loc, storeETy, state); - Value idx1 = b.create(loc, 1); - return b.create(loc, storeTy, state, store, - ValueRange{idx1}); - } - - if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { - Value idx2 = b.create(loc, 2); - Value idx3 = b.create(loc, 3); - std::pair states = - splitI64(ArithOpBuilder(b, loc, state)); - Value state0 = - b.create(loc, storeETy, states.first.val()); - Value state1 = - b.create(loc, storeETy, states.second.val()); - Value insert0 = b.create(loc, storeTy, state0, store, - ValueRange{idx2}); - Value insert1 = b.create(loc, storeTy, state1, insert0, - ValueRange{idx3}); - return insert1; - } - - return nullptr; -} - -Value reshapeToTarget(OpBuilder &builder, Location loc, ShapedType destTy, - Value src) { - auto srcTy = cast(src.getType()); - // Expand out to the target shape. - - auto reassociationIndices = - getReassociationIndicesForCollapse(destTy.getShape(), srcTy.getShape()); - if (reassociationIndices.has_value()) { - src = builder.create(loc, destTy, src, - reassociationIndices.value()); - } - - // It is also possible our target is Rank-0, then we would - // need to collapse. - reassociationIndices = - getReassociationIndicesForCollapse(srcTy.getShape(), destTy.getShape()); - if (reassociationIndices.has_value()) { - src = builder.create(loc, destTy, src, - reassociationIndices.value()); - } - - return src; -} - -// Compute the shape for computing three fry. -std::pair threeFry32Shape(ShapedType resultTy) { - if (resultTy.getRank() == 0) { - return {resultTy, 0}; - } - - ArrayRef shape = resultTy.getShape(); - uint64_t halfDim = - std::max_element(shape.begin(), shape.end()) - shape.begin(); - - for (int i = 0, s = shape.size(); i < s; i++) { - if (shape[i] & 0x1) - continue; - halfDim = i; - break; - } - - llvm::SmallVector newShape(shape); - newShape[halfDim] = (newShape[halfDim] + 1) / 2; - if (halfDim == (newShape.size() - 1)) { - newShape.push_back(1); - } - - return {RankedTensorType::get(newShape, resultTy.getElementType()), halfDim}; -} - -/// This implementation generates a 32-bit tensor of ThreeFry random numbers. -/// It matches the XLA implementation bit-exact and includes an inefficient -/// method of concatenating / slicing the pairs of generated numbers. -/// -/// We should consider dropping the complex slicing and simply generating -/// 2x the values, then downcast to a 32-bit. It substantially simplifies -/// the computation and avoids the concat / slice behavior. -LogicalResult generateLinalgThreeFry32(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &store, - Value &result) { - Type resultETy = resultTy.getElementType(); - - // Extract the stateful values as an i64 and increment the state ahead. - Value initialState = extractState64(builder, loc, store); - if (!initialState) - return failure(); - - std::pair keys = extractKey32(builder, loc, store); - if (!keys.first || !keys.second) - return failure(); - - ArithOpBuilder key0(builder, loc, keys.first); - ArithOpBuilder key1(builder, loc, keys.second); - - // Compute the intermediate type we use to compute three fry values, including - // the dimension that was halved. - auto pair = threeFry32Shape(resultTy); - ShapedType intermediateType = pair.first; - int64_t halfDim = pair.second; - int64_t count = intermediateType.getNumElements(); - - // Compute the number of random i64s generated and increment state. - Value countVal = - builder.create(loc, builder.getI64IntegerAttr(count)); - Value newState = builder.create(loc, initialState, countVal); - - // Generate a 1D tensor with for the random values. - Value destLeft = builder.create( - loc, ArrayRef({count}), resultETy); - Value destRight = builder.create( - loc, ArrayRef({count}), resultETy); - - ShapedType destTy = llvm::cast(destLeft.getType()); - - SmallVector indexingMaps(2, builder.getMultiDimIdentityMap(1)); - SmallVector iterators(1, utils::IteratorType::parallel); - - linalg::GenericOp generic = builder.create( - loc, TypeRange{destTy, destTy}, - /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{destLeft, destRight}, - /*indexingMaps=*/indexingMaps, iterators, - [&](OpBuilder &b, Location nestedLoc, ValueRange) { - // Grab three fry results and write to each array. - auto split = runThreeFry2xi32( - key0, key1, ArithOpBuilder(b, nestedLoc, initialState)); - auto first = split.first.truncI(resultETy.getIntOrFloatBitWidth()); - auto second = split.second.truncI(resultETy.getIntOrFloatBitWidth()); - b.create(loc, ValueRange{first.val(), second.val()}); - }); - - if (resultTy.getNumElements() == 1) { - result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); - store = setState64(builder, loc, store, newState); - return success(); - } - - // Reshape to the target size and concatenate on the dimension following the - // half dimension. - Value random0 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); - Value random1 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); - Value concatenate = builder.create( - loc, ValueRange{random0, random1}, - builder.getI64IntegerAttr(halfDim + 1)); - - // Collapse the concat dimension back into the parent. - llvm::SmallVector collapseShape(resultTy.getShape()); - collapseShape[halfDim] = - collapseShape[halfDim] + (collapseShape[halfDim] & 1); - Value reshape = builder.create( - loc, resultTy.clone(collapseShape), concatenate); - - // Slice to only the required results. - llvm::SmallVector offset(resultTy.getRank(), 0); - llvm::SmallVector stride(resultTy.getRank(), 1); - Value slice = builder.create( - loc, resultTy, reshape, builder.getDenseI64ArrayAttr(offset), - builder.getDenseI64ArrayAttr(resultTy.getShape()), - builder.getDenseI64ArrayAttr(stride)); - - // Set the new tensor values. - store = setState64(builder, loc, store, newState); - result = slice; - - return success(); -} - -LogicalResult generateLinalgThreeFry64(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &store, - Value &result) { - Type resultETy = resultTy.getElementType(); - int64_t count = resultTy.getNumElements(); - - // Extract the stateful values as an i64 and increment the state ahead. - Value initialState = extractState64(builder, loc, store); - if (!initialState) - return failure(); - - std::pair keys = extractKey32(builder, loc, store); - if (!keys.first || !keys.second) - return failure(); - - ArithOpBuilder key0(builder, loc, keys.first); - ArithOpBuilder key1(builder, loc, keys.second); - - // Compute the number of random i64s generated and increment state. - Value countVal = - builder.create(loc, builder.getI64IntegerAttr(count)); - Value newState = builder.create(loc, initialState, countVal); - - // Generate a 1D tensor with for the random values. - Value dest = builder.create(loc, ArrayRef({count}), - resultETy); - ShapedType destTy = llvm::cast(dest.getType()); - - SmallVector indexingMaps(1, builder.getMultiDimIdentityMap(1)); - SmallVector iterators(1, utils::IteratorType::parallel); - - auto random = builder.create( - loc, destTy, /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{dest}, - /*indexingMaps=*/indexingMaps, iterators, - [&](OpBuilder &b, Location nestedLoc, ValueRange) { - // Generate three fry results, fuse, and return an - // i64. - auto split = runThreeFry2xi32( - key0, key1, ArithOpBuilder(b, nestedLoc, initialState)); - Value result = fuseI32s(split.first, split.second).val(); - b.create(nestedLoc, result); - }); - - store = setState64(builder, loc, store, newState); - result = reshapeToTarget(builder, loc, resultTy, random.getResult(0)); - return success(); -} - -using PhiloxKey = std::pair; -using PhiloxState = std::array; - -// Computes high and low words from multiplying 32 bit integers. -// Per the paper, mulhi and mullo of the same arguments can be computed -// Simultaneously in a single instruction on x86 architectures. -std::pair multiplyHilo(ArithOpBuilder counter, - ArithOpBuilder key) { - counter = counter.extendUI(64); - key = key.extendUI(64); - ArithOpBuilder product = counter * key; - ArithOpBuilder ci64 = counter.constantI(/*value=*/32, /*bits=*/64); - ArithOpBuilder hi = product >> ci64; - hi = hi.truncI(32); - product = product.truncI(32); - return std::pair{hi, product}; -} - -PhiloxState philoxRound(PhiloxState x, PhiloxKey key) { - // These are philox specific constants. - ArithOpBuilder m0 = x[0].constantI(0xD2511F53, 32); - ArithOpBuilder m1 = x[2].constantI(0xCD9E8D57, 32); - std::pair p0 = multiplyHilo(x[0], m0); - std::pair p1 = multiplyHilo(x[2], m1); - - PhiloxState state = {p1.first ^ x[1] ^ key.first, p1.second, - p0.first ^ x[3] ^ key.second, p0.second}; - return state; -} - -PhiloxKey raiseKey(PhiloxKey key) { - // These are philox specific constants. - ArithOpBuilder w0 = key.first.constantI(0x9E3779B9, 32); - ArithOpBuilder w1 = key.first.constantI(0xBB67AE85, 32); - return PhiloxKey{key.first + w0, key.second + w1}; -} - -// Implements the Philox 4x32 counter-based PRNG algorithm. -// The Philox PRNG has been proposed in: -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -std::array runPhilox4x32(PhiloxKey key, - ArithOpBuilder state) { - ArithOpBuilder index = state.linalgIndex(0); - index = index.indexCast(64); - index = index + state; - - // Split into the 2xi32 used for threefry. - std::pair input = splitI64(index); - ArithOpBuilder input0 = input.first; - ArithOpBuilder input1 = input.second; - - // We initialize the state as such to match the XLA implementation. - PhiloxState state4 = {input0, input1, key.first, key.second}; - - // We perform 10 rounds to match the XLA implementation. - constexpr int kNumRounds = 10; - for (int round = 0; round < kNumRounds; ++round, key = raiseKey(key)) { - state4 = philoxRound(state4, key); - } - return state4; -} - -// Generates an array of primitive type U32 with the given shape containing -// random bits generated by the Philox algorithm. Returns the array and the new -// state of the random number generator. -LogicalResult generateLinalgPhilox32(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &store, - Value &result) { - Type resultETy = resultTy.getElementType(); - - Value initialState = extractState64(builder, loc, store); - if (!initialState) - return failure(); - - std::pair keys = extractKey32(builder, loc, store); - if (!keys.first || !keys.second) - return failure(); - - int64_t numElements = resultTy.getNumElements(); - int64_t count = (numElements + 3) / 4; - ShapedType intermediateType = - RankedTensorType::get({count, 1}, resultTy.getElementType()); - int64_t concatDim = 1; - - // Compute the number of random i64s generated and increment state. - Value countVal = - builder.create(loc, builder.getI64IntegerAttr(count)); - Value newState = builder.create(loc, initialState, countVal); - - // set up four outputs - Value dest0 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest1 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest2 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest3 = builder.create(loc, ArrayRef({count}), - resultETy); - - ShapedType destTy = cast(dest0.getType()); - - SmallVector indexingMaps(4, builder.getMultiDimIdentityMap(1)); - SmallVector iterators(1, utils::IteratorType::parallel); - - linalg::GenericOp generic = builder.create( - loc, TypeRange{destTy, destTy, destTy, destTy}, - /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{dest0, dest1, dest2, dest3}, - /*indexingMaps=*/indexingMaps, iterators, - [&](OpBuilder &b, Location nestedLoc, ValueRange) { - auto output = - runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first), - ArithOpBuilder(b, nestedLoc, keys.second)}, - ArithOpBuilder(b, nestedLoc, initialState)); - auto out0 = output[0].truncI(resultETy.getIntOrFloatBitWidth()); - auto out1 = output[1].truncI(resultETy.getIntOrFloatBitWidth()); - auto out2 = output[2].truncI(resultETy.getIntOrFloatBitWidth()); - auto out3 = output[3].truncI(resultETy.getIntOrFloatBitWidth()); - b.create( - loc, ValueRange{out0.val(), out1.val(), out2.val(), out3.val()}); - }); - - if (resultTy.getNumElements() == 1) { - result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); - store = setState64(builder, loc, store, newState); - return success(); - } - - Value r0 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); - Value r1 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); - Value r2 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(2)); - Value r3 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(3)); - - Value concatenate = builder.create( - loc, ValueRange{r0, r1, r2, r3}, builder.getI64IntegerAttr(concatDim)); - - // Collapse the concat dimension back into the parent. - llvm::SmallVector collapseShape(intermediateType.getShape()); - collapseShape[0] = collapseShape[0] * 4; - Value reshapeIntermediate = builder.create( - loc, resultTy.clone(collapseShape), concatenate); - - // Slice to only the required results. - collapseShape[0] = resultTy.getNumElements(); - - auto sliceResultTy = intermediateType.clone(collapseShape); - llvm::SmallVector offset(sliceResultTy.getRank(), 0); - llvm::SmallVector stride(sliceResultTy.getRank(), 1); - Value slice = builder.create( - loc, sliceResultTy, reshapeIntermediate, - builder.getDenseI64ArrayAttr(offset), - builder.getDenseI64ArrayAttr(collapseShape), - builder.getDenseI64ArrayAttr(stride)); - Value reshapeResult = - builder.create(loc, resultTy, slice); - - // Set the new tensor values. - store = setState64(builder, loc, store, newState); - result = reshapeResult; - - return success(); -} - -LogicalResult generateLinalgPhilox64(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &store, - Value &result) { - Type resultETy = resultTy.getElementType(); - - Value initialState = extractState64(builder, loc, store); - if (!initialState) - return failure(); - - std::pair keys = extractKey32(builder, loc, store); - if (!keys.first || !keys.second) - return failure(); - - int64_t numElements = resultTy.getNumElements(); - int64_t count = (numElements + 1) / 2; - ShapedType intermediateType = - RankedTensorType::get({count, 1}, resultTy.getElementType()); - int64_t concatDim = 1; - - // Compute the number of random i64s generated and increment state. - Value countVal = - builder.create(loc, builder.getI64IntegerAttr(count)); - Value newState = builder.create(loc, initialState, countVal); - - // set up four outputs - Value dest0 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest1 = builder.create(loc, ArrayRef({count}), - resultETy); - ShapedType destTy = cast(dest0.getType()); - - SmallVector indexingMaps(2, builder.getMultiDimIdentityMap(1)); - SmallVector iterators(1, utils::IteratorType::parallel); - - linalg::GenericOp generic = builder.create( - loc, TypeRange{destTy, destTy}, - /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{dest0, dest1}, - /*indexingMaps=*/indexingMaps, iterators, - [&](OpBuilder &b, Location nestedLoc, ValueRange) { - auto output = - runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first), - ArithOpBuilder(b, nestedLoc, keys.second)}, - ArithOpBuilder(b, nestedLoc, initialState)); - auto out0 = output[0]; - auto out1 = output[1]; - auto out2 = output[2]; - auto out3 = output[3]; - Value result1 = fuseI32s(out0, out1).val(); - Value result2 = fuseI32s(out2, out3).val(); - b.create(loc, ValueRange{result1, result2}); - }); - - if (resultTy.getNumElements() == 1) { - result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); - store = setState64(builder, loc, store, newState); - return success(); - } - - Value r0 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); - Value r1 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); - Value concatenate = builder.create( - loc, ValueRange{r0, r1}, builder.getI64IntegerAttr(concatDim)); - - // Collapse the concat dimension back into the parent. - llvm::SmallVector collapseShape(intermediateType.getShape()); - collapseShape[0] = collapseShape[0] * 2; - Value reshapeIntermediate = builder.create( - loc, resultTy.clone(collapseShape), concatenate); - - // Slice to only the required results. - collapseShape[0] = resultTy.getNumElements(); - - auto sliceResultTy = intermediateType.clone(collapseShape); - llvm::SmallVector offset(sliceResultTy.getRank(), 0); - llvm::SmallVector stride(sliceResultTy.getRank(), 1); - Value slice = builder.create( - loc, sliceResultTy, reshapeIntermediate, - builder.getDenseI64ArrayAttr(offset), - builder.getDenseI64ArrayAttr(collapseShape), - builder.getDenseI64ArrayAttr(stride)); - Value reshapeResult = - builder.create(loc, resultTy, slice); - - // Set the new tensor values. - store = setState64(builder, loc, store, newState); - result = reshapeResult; - - return success(); -} - -LogicalResult generateLinalgThreeFry(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &state, - Value &result) { - Type eTy = resultTy.getElementType(); - unsigned bitwidth = eTy.getIntOrFloatBitWidth(); - - if (bitwidth == 64) { - return generateLinalgThreeFry64(builder, loc, resultTy, state, result); - } - if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) { - return generateLinalgThreeFry32(builder, loc, resultTy, state, result); - } - - return failure(); -} - -LogicalResult generateLinalgPhilox(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &state, - Value &result) { - Type eTy = resultTy.getElementType(); - unsigned bitwidth = eTy.getIntOrFloatBitWidth(); - if (bitwidth == 64) { - return generateLinalgPhilox64(builder, loc, resultTy, state, result); - } - - // The 32 bit implementation trancates to result eTy. - if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) { - return generateLinalgPhilox32(builder, loc, resultTy, state, result); - } - - return failure(); -} - -struct RngBitGeneratorConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::RngBitGeneratorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value state = adaptor.getInitialState(); - auto resultTy = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult(1).getType())); - if (!resultTy) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - - if (op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::THREE_FRY) { - Value random; - if (failed( - generateLinalgThreeFry(rewriter, loc, resultTy, state, random))) { - return failure(); - } - rewriter.replaceOp(op, {state, random}); - return success(); - } - - if (op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::PHILOX || - op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::DEFAULT) { - Value random; - if (failed( - generateLinalgPhilox(rewriter, loc, resultTy, state, random))) { - return failure(); - } - rewriter.replaceOp(op, {state, random}); - return success(); - } - - return failure(); - } -}; - -struct RngUniformConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::RngOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // We only handle uniform distributions. - if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::UNIFORM) { - return failure(); - } - // TODO(raikonenfnu): Handle other element types as well. - auto minTy = dyn_cast(adaptor.getA().getType()); - auto maxTy = dyn_cast(adaptor.getB().getType()); - if (!isa(minTy.getElementType()) || - !isa(maxTy.getElementType())) { - return rewriter.notifyMatchFailure( - op, "expected min/max for rng op to be FloatType"); - } - auto targetTy = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); - if (!targetTy) { - return rewriter.notifyMatchFailure( - op, "expected target shape of rng op to be ShapedType"); - } - auto loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands()); - // Creates index map using target matrix's rank. - auto targetRank = targetTy.getRank(); - SmallVector indexingMaps( - 2, AffineMap::get(targetRank, /*symbolCount=*/0, - SmallVector({}), rewriter.getContext())); - indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank)); - const int kInitialSeed = 0; - - // Generic region with LCG Algorithm that make use of element index from: - // https://reviews.llvm.org/D101364 - auto linalgOp = rewriter.create( - loc, /*resultTensors=*/targetTy, - /*inputs=*/ - ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]}, - /*outputs=*/emptyTensor, indexingMaps, - getParallelAndReductionIterators(/*nLoops=*/targetRank, - /*nReduction=*/0), - [&](OpBuilder &b, Location loc, ValueRange args) { - llvm::SmallVector updateVec = {b.create( - loc, b.getI32IntegerAttr(kInitialSeed))}; - Value multiplier = - b.create(loc, b.getI32IntegerAttr(1103515245)); - Value incrementStep = - b.create(loc, b.getI32IntegerAttr(12345)); - // For output matrix with rank N: - // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr - // ... - // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr - for (int i = 0; i < targetRank; i++) { - Value update = updateVec.back(); - Value ind = b.create(loc, i); - Value castInd = - b.create(loc, b.getI32Type(), ind); - Value addRes = b.create(loc, castInd, update); - Value multRes = b.create(loc, addRes, multiplier); - Value incRes = b.create(loc, multRes, incrementStep); - updateVec.push_back(incRes); - } - // Scaling = (max - min) * const(F64, 2.3283064E-10) - // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)). - Value epsilon = b.create( - loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10)); - Value range = b.create(loc, args[1], args[0]); - Value scale = b.create(loc, range, epsilon); - // Res = cast(T, cast(F64, tempN) * scaling + min) - Value updateCast = b.create( - loc, targetTy.getElementType(), updateVec.back()); - Value scaleUpdate = b.create(loc, updateCast, scale); - Value res = b.create(loc, scaleUpdate, args[0]); - b.create(loc, res); - }, - linalg::getPrunedAttributeList(op)); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; -} // namespace - -namespace detail { -void populateStableHloRandomToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - patterns->add(typeConverter, - context); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp deleted file mode 100644 index ce8ba53ac83c..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp +++ /dev/null @@ -1,723 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO reduction ops to Linalg dialect. -// These patterns are separated out to their own file to save on the compilation -// times. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -/// Returns true when reduction `op` is not supported and should be filtered -/// out. -static bool isUnsupported(mlir::stablehlo::ReduceOp op) { - // Empty reductions are not supported. We expect canonicalization patterns to - // handle them. - if (op.getDimensions().empty()) - return true; - - // We require all reduce shapes to be the same, up to the element types, so - // we can just the first operand and the first result as a representative. - if (auto inputTy = - dyn_cast(op.getInputs().getType().front())) { - return llvm::is_contained(inputTy.getShape(), 0); - } - - return false; -} - -/// Returns a permutation AffineMap that puts all reduction dimensions to the -/// last. The order of parallel loops and reduction loops are all sorted. E.g., -/// if `rank` is 4 and `reductionDims` is {1, 3}, then -/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of -/// the AffineMap is returned. -AffineMap getTransposeMapForReduction(MLIRContext *context, int rank, - ArrayRef reductionDims) { - llvm::SmallSetVector s(reductionDims.begin(), reductionDims.end()); - - SmallVector permutation; - for (int i = 0; i < rank; ++i) { - if (!s.contains(i)) { - permutation.push_back(i); - } - } - - llvm::append_range(permutation, reductionDims); - auto map = AffineMap::getPermutationMap(permutation, context); - return inversePermutation(map); -} - -SmallVector -getReduceOpEmptyTensorDynSizes(OpBuilder &b, Location loc, Value arg, - ShapedType resultType, - ArrayRef reductionDims) { - llvm::SmallSetVector s(reductionDims.begin(), reductionDims.end()); - - SmallVector parallelDims; - SmallVector dynShape; - int rank = cast(arg.getType()).getRank(); - for (int i = 0, j = 0; i < rank; ++i) { - if (s.contains(i)) - continue; - if (!resultType.isDynamicDim(j++)) - continue; - dynShape.push_back(b.create(loc, arg, i)); - } - - return dynShape; -} - -struct ReduceRegionReturnOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isInBodyOfLinalgOps(op)) { - return failure(); - } - - SmallVector operands(adaptor.getOperands()); - for (Value &operand : operands) { - if (isa(operand.getType())) { - Location loc = operand.getLoc(); - operand = rewriter.create(loc, operand); - } - } - rewriter.replaceOpWithNewOp(op, operands); - return success(); - } -}; - -struct ReduceOpToGenericConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (isUnsupported(op)) { - return rewriter.notifyMatchFailure(op, - "unsupported reduce (noop or empty)"); - } - - Location loc = op.getLoc(); - - int numOperands = static_cast(adaptor.getInputs().size()); - - if (llvm::any_of(adaptor.getInputs(), [](Value v) { - return !isa(v.getType()); - })) { - return rewriter.notifyMatchFailure(op, "expects known-rank args"); - } - auto srcRank = cast(adaptor.getInputs()[0].getType()).getRank(); - - SmallVector reductionDims = llvm::to_vector(op.getDimensions()); - - SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - - SmallVector outputs; - SmallVector indexingMaps; - for (auto [operand, initValue, resultType] : llvm::zip_equal( - adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) { - // Check if init_value is constant. If so, inline the value into the - // region. - initValue = rewriter.createOrFold(loc, initValue); - - SmallVector dynShape = getReduceOpEmptyTensorDynSizes( - rewriter, loc, operand, cast(resultType), reductionDims); - auto emptyTensor = - getEmptyTensor(rewriter, loc, cast(resultType), dynShape); - Value filledTensor = - rewriter.create(loc, initValue, emptyTensor).result(); - outputs.push_back(filledTensor); - } - - // Prepare indexing maps for linalg generic op. The elements are for src - // and dst. Transpose `src` to make the reduction loops be the innermost, - // because it's easier to fully utilize processors. - indexingMaps.append(numOperands, - getTransposeMapForReduction(rewriter.getContext(), - static_cast(srcRank), - reductionDims)); - - // The indexing map of `dst` should drop the reduction loops. Since the - // reduction loops now are all in the innermost, drops - // `reduction_dims.size()` dimensions. We don't need an inverse - // permutation here because they are the same. - SmallVector exprs; - for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i) { - exprs.push_back(rewriter.getAffineDimExpr(i)); - } - indexingMaps.append(numOperands, - AffineMap::get(srcRank, /*symbolCount=*/0, exprs, - rewriter.getContext())); - - auto linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/resultTypes, adaptor.getInputs(), - /*outputBuffers=*/ValueRange{outputs}, indexingMaps, - getParallelAndReductionIterators(srcRank, reductionDims.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. The reduce op region apply function - // has a signature (lhs, rhs) -> output, all of the same tensor type t. - // This is converted to a function with the same signature but with - // element types. E.g., "(tensor, tensor) -> tensor" will - // be converted to "(f32, f32, f32)". - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getBody(), region, region.end()); - TypeConverter::SignatureConversion signatureConverter(numOperands * 2); - - // The stablehlo ReduceOp requires that the seed be used as a LHS operand - // inside the region, and the seed is encoded in linalg in the initial out - // value, so modify the signature of the block and the value mappings, so - // the output args will correlate with the original LHS and the inputs - // correlate with the original RHS. - for (auto [idx, val] : llvm::enumerate(op.getInputs())) { - signatureConverter.addInputs( - /*origInputNo=*/idx + numOperands, - // type for the new operand number 'idx'. - typeConverter->convertType( - cast(val.getType()).getElementType())); - } - for (auto [idx, val] : llvm::enumerate(op.getInitValues())) { - signatureConverter.addInputs( - /*origInputNo=*/idx, - // type for the new operand number 'idx' + 'numOperands'. - typeConverter->convertType( - cast(val.getType()).getElementType())); - } - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct ReduceOpToReduceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (isUnsupported(op)) { - return rewriter.notifyMatchFailure(op, - "unsupported reduce (noop or empty)"); - } - - auto reductionDims = llvm::to_vector(op.getDimensions()); - // stablehlo.reduce doesn't specify the order of the reduction dimensions. - llvm::sort(reductionDims); - - auto toRankedTensor = [](Value v) -> RankedTensorType { - return dyn_cast(v.getType()); - }; - - SmallVector outputs; - SmallVector operandTypes, initTypes; - SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - - Location loc = op.getLoc(); - for (auto [operand, initValue, resultType] : llvm::zip_equal( - adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) { - auto initType = toRankedTensor(initValue); - if (!initType) - return rewriter.notifyMatchFailure(op, - "expects known-rank init values"); - initTypes.push_back(initType); - auto operandType = toRankedTensor(operand); - if (!operandType) - return rewriter.notifyMatchFailure(op, "expects known-rank operands"); - operandTypes.push_back(operandType); - initValue = rewriter.createOrFold(loc, initValue); - auto tensorResultType = cast(resultType); - // For linalg.reduce, the result type's dimensions must match the input's - // dimensions, whereas StableHLO allows replacing static dimensions with - // dynamic ones. - SmallVector resultShape; - SmallVector dynShape; - for (auto [index, dim] : - llvm::enumerate(cast(operand.getType()).getShape())) { - if (!llvm::is_contained(reductionDims, index)) { - resultShape.push_back(dim); - if (ShapedType::isDynamic(dim)) { - dynShape.push_back( - rewriter.create(loc, operand, index)); - } - } - } - - Value emptyTensor = rewriter.create( - loc, resultShape, tensorResultType.getElementType(), dynShape); - Value filledTensor = - rewriter.create(loc, initValue, emptyTensor).result(); - outputs.push_back(filledTensor); - } - - auto linalgOp = rewriter.create( - loc, adaptor.getInputs(), outputs, reductionDims, - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getBody(), region, region.end()); - - // Convert the signature of the body. The reduce op 'computation' region - // apply function has a signature with tensor types, this is converted to a - // function with element types. E.g. the signature "(tensor, - // tensor) -> tensor" will be converted to "(f32, f32) -> f32". - // Also, we need to swap the operands of the function. The stablehlo.reduce - // op expects the init values to be the first parameters of the apply - // function, while the linalg.reduction op expects the init values as the - // last parameters of the 'combiner' region apply function. - TypeConverter::SignatureConversion signatureConverter( - linalgOp.getNumDpsInputs() * 2); - assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits()); - for (const auto &[idx, val] : llvm::enumerate(operandTypes)) { - signatureConverter.addInputs( - /*origInputNo=*/idx + linalgOp.getNumDpsInputs(), - // type for new operand number 'idx'. - typeConverter->convertType(val.getElementType())); - } - for (const auto &[idx, val] : llvm::enumerate(initTypes)) { - signatureConverter.addInputs( - /*origInputNo=*/idx, - // type for new operand number 'idx' + linalgOp.getNumInputs() - typeConverter->convertType(val.getElementType())); - } - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - - // Cast the result to the correct type. - SmallVector results; - for (auto [result, resultType] : - llvm::zip(linalgOp.getResults(), resultTypes)) { - results.push_back( - rewriter.createOrFold(loc, resultType, result)); - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ReduceWindowOpOnTensorsGenericConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = op->getContext(); - Location loc = op.getLoc(); - llvm::SmallVector initValues = adaptor.getInitValues(); - llvm::SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - auto numOperands = initValues.size(); - - llvm::SmallVector windowDimensions(op.getWindowDimensions()); - - llvm::SmallVector padding; - if (op.getPadding()) { - padding = extract1DVector(*op.getPadding()); - } - - llvm::SmallVector baseDilations; - if (op.getBaseDilations()) { - baseDilations = llvm::to_vector(*op.getBaseDilations()); - } - - llvm::SmallVector windowStrides(windowDimensions.size(), 1); - if (op.getWindowStrides()) { - windowStrides = llvm::to_vector(*op.getWindowStrides()); - } - - llvm::SmallVector windowDilations(windowDimensions.size(), 1); - if (op.getWindowDilations()) { - windowDilations = llvm::to_vector(*op.getWindowDilations()); - } - - auto rank = static_cast(windowDimensions.size()); - SmallVector srcExprs; - SmallVector windowExprs; - SmallVector dstExprs; - SmallVector filteredWindowDims; - - int windowDim = 0; - for (int64_t i = 0; i < rank; i++) { - AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx); - - if (windowStrides[i] != 1) - srcExpr = srcExpr * windowStrides[i]; - - if (windowDimensions[i] != 1) { - filteredWindowDims.push_back(windowDimensions[i]); - AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx); - windowExprs.push_back(windowExpr); - - if (windowDilations[i] != 1) - windowExpr = windowExpr * windowDilations[i]; - - srcExpr = srcExpr + windowExpr; - windowDim++; - } - - srcExprs.push_back(srcExpr); - dstExprs.push_back(mlir::getAffineDimExpr(i, ctx)); - } - - SmallVector inferredMaps(3, AffineMap::get(ctx)); - if (rank > 0) { - inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); - } - - SmallVector indexingMaps; - - indexingMaps.append(numOperands, inferredMaps[0]); - indexingMaps.append(1, inferredMaps[1]); - indexingMaps.append(numOperands, inferredMaps[2]); - - // Setup the initial values. - llvm::SmallVector broadcastValues; - for (uint64_t i = 0, s = initValues.size(); i < s; i++) { - Value initValue = initValues[i]; - auto resultTy = llvm::cast(resultTypes[i]); - if (!resultTy.hasStaticShape()) - return failure(); - - auto broadcastSizes = rewriter.getDenseI64ArrayAttr(resultTy.getShape()); - broadcastValues.push_back(rewriter.create( - loc, resultTy, initValue, broadcastSizes)); - } - - llvm::SmallVector inputs = llvm::to_vector(adaptor.getInputs()); - - // Pad as necessary. - if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) || - llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) { - llvm::SmallVector staticLows(rank, 0); - llvm::SmallVector staticHighs(rank, 0); - for (int64_t i = 0; i < static_cast(padding.size()); i += 2) { - staticLows[i / 2] = padding[i]; - staticHighs[i / 2] = padding[i + 1]; - } - // Translate base dilation into interior padding. - llvm::SmallVector staticInteriors(rank, 0); - for (auto [idx, dilation] : llvm::enumerate(baseDilations)) { - staticInteriors[idx] = dilation - 1; - } - - for (auto [input, initValue] : llvm::zip(inputs, initValues)) { - input = rewriter.create( - loc, input, initValue, staticLows, staticHighs, staticInteriors); - } - } - - // Add the extra input for the reduction dimension. - inputs.push_back(rewriter.create(loc, filteredWindowDims, - rewriter.getF32Type())); - - auto linalgOp = rewriter.create( - loc, /*resultTensors=*/resultTypes, - /*inputs=*/inputs, - /*outputs=*/broadcastValues, indexingMaps, - getParallelAndReductionIterators(rank + filteredWindowDims.size(), - filteredWindowDims.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. This includes converting scalar - // tensors to their scalar values and inserting an additional block arg for - // the window arg. - Region ®ion = linalgOp.getRegion(); - rewriter.cloneRegionBefore(op.getBody(), region, region.end()); - - TypeConverter::SignatureConversion signatureConverter( - inputs.size() + op->getNumResults() - 1); - - // ReduceWindow requires that the seed be used as a LHS operand inside the - // region, and the seed is encoded in linalg in the initial out value, so - // modify the signature of the block and the value mappings, so the output - // args will correlate with the LHS and the inputs correlate with the RHS. - for (auto [i, type] : llvm::enumerate(resultTypes)) { - auto idx = inputs.size() + i - 1; - signatureConverter.addInputs(idx, - cast(type).getElementType()); - } - - signatureConverter.addInputs( - cast(inputs.back().getType()).getElementType()); - - for (auto [i, input] : - llvm::enumerate(ArrayRef(inputs).drop_back())) { - signatureConverter.addInputs( - i, cast(input.getType()).getElementType()); - } - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct ReduceWindowOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - /// Get the operation used for reduction applied to `result_index`th result. - /// Its expected to be a binary operation that consumes `result_index`th and - /// `result_index + getInputs().size`th arguments of the body. - static Operation *getReductionOp(mlir::stablehlo::ReduceWindowOp op, - int resultIndex) { - auto returnOp = - cast(op.getBody().front().getTerminator()); - Operation *computeOp = returnOp.getResults()[resultIndex].getDefiningOp(); - if (computeOp->getNumOperands() != 2) - return nullptr; - auto arg0 = llvm::dyn_cast(computeOp->getOperand(0)); - auto arg1 = llvm::dyn_cast(computeOp->getOperand(1)); - if (!arg0 || !arg1) - return nullptr; - int64_t arg0Num = arg0.getArgNumber(); - int64_t arg1Num = arg1.getArgNumber(); - int64_t otherArgIndex = resultIndex + op.getInputs().size(); - if (arg0Num == resultIndex && arg1Num == otherArgIndex) - return computeOp; - if (arg0Num == otherArgIndex && arg1Num == resultIndex && - computeOp->hasTrait()) - return computeOp; - return nullptr; - } - - /// stablehlo.reduce_window is mapped to a linalg.pooling operation. The type - /// of the pooling is determined based on the body of the reduce window - /// operation. This class enumerates the different variants. - enum class PoolingType { - kInvalid, - k2DMin, - k3DMin, - k2DMax, - k3DMax, - k2DAdd, - k3DAdd, - }; - - static PoolingType getPoolingType(mlir::stablehlo::ReduceWindowOp reduceOp, - int resultIndex) { - auto rank = llvm::cast(reduceOp.getResultTypes()[resultIndex]) - .getRank(); - if (Operation *op = getReductionOp(reduceOp, resultIndex)) { - if (isa(*op) && rank == 4) - return PoolingType::k2DMin; - if (isa(*op) && rank == 5) - return PoolingType::k3DMin; - if (isa(*op) && rank == 4) - return PoolingType::k2DMax; - if (isa(*op) && rank == 5) - return PoolingType::k3DMax; - if (isa(*op) && rank == 4) - return PoolingType::k2DAdd; - if (isa(*op) && rank == 5) - return PoolingType::k3DAdd; - } - return PoolingType::kInvalid; - } - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - int rank = llvm::cast(op.getResultTypes()[0]).getRank(); - if (rank != 4 && rank != 5) { - return rewriter.notifyMatchFailure( - op, "expected NHWC/NDHWC pooling-based op"); - } - - if (op.getPadding() && !isSplatValue(*op.getPadding(), 0)) { - return rewriter.notifyMatchFailure(op, "require paddings are all zero"); - } - - if (op.getBaseDilations() && !isSplatValue(*op.getBaseDilations(), 1)) { - return rewriter.notifyMatchFailure(op, "expected undilated base"); - } - - int lastDim = rank - 1; - SmallVector fakeWindowShapes; - for (int i = 1; i < lastDim; ++i) { - fakeWindowShapes.push_back(op.getWindowDimensions()[i]); - } - - if (op.getWindowStrides() && - (op.getWindowStrides().value()[0] != 1 || - op.getWindowStrides().value()[lastDim] != 1)) { - return rewriter.notifyMatchFailure( - op, "expected window_strides to be [1,x,y,(z),1]"); - } - if (op.getWindowDimensions()[0] != 1 || - op.getWindowDimensions()[lastDim] != 1) { - return rewriter.notifyMatchFailure( - op, "expected window_dimensions to be [1,x,y,(z),1]"); - } - - Attribute strides; - SmallVector vec; - if (op.getWindowStridesAttr()) { - for (int i = 1; i < lastDim; ++i) { - vec.push_back(op.getWindowStrides().value()[i]); - } - } else { - vec.assign(rank - 2, 1); - } - strides = rewriter.getI64VectorAttr(vec); - - Attribute dilations; - vec.clear(); - if (op.getWindowDilations()) { - for (int i = 1; i < lastDim; ++i) { - vec.push_back(op.getWindowDilations().value()[i]); - } - } else { - vec.assign(rank - 2, 1); - } - dilations = rewriter.getI64VectorAttr(vec); - - SmallVector poolingOps; - - ValueRange operands = adaptor.getInputs(); - ValueRange initValues = adaptor.getInitValues(); - for (auto it : llvm::zip(op.getResults(), operands, initValues)) { - OpResult result = std::get<0>(it); - Value input = std::get<1>(it); - Value initValue = std::get<2>(it); - auto resultType = cast(result.getType()); - if (!cast(input.getType()).getElementType().isF32()) { - return rewriter.notifyMatchFailure(op, - "expected element type to be f32"); - } - - // Create a fake window dimension. - auto fakeWindowDims = rewriter.create( - loc, fakeWindowShapes, resultType.getElementType()); - - SmallVector resultDynamicDims; - for (const auto &en : llvm::enumerate(resultType.getShape())) { - if (!ShapedType::isDynamic(en.value())) - continue; - Value dimSize = rewriter.create(loc, input, en.index()); - if (en.index() == 0 || static_cast(en.index()) == rank - 1) { - // batch dims and channel dims can be derived from input dims - // directly. - resultDynamicDims.push_back(dimSize); - } else { - auto i = en.index() - 1; - auto stride = - llvm::cast(strides).getValues()[i]; - auto dilation = llvm::cast(dilations) - .getValues()[i]; - // let j = i * stride - // output[i] = reduce( input[j, j + window_size * dilation) ) - Value offset = rewriter.create( - loc, fakeWindowShapes[i] * dilation); - dimSize = rewriter.create(loc, dimSize, offset); - dimSize = rewriter.create( - loc, dimSize, - rewriter.create(loc, stride)); - dimSize = rewriter.create( - loc, dimSize, rewriter.create(loc, 1)); - resultDynamicDims.push_back(dimSize); - } - } - Value emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), - resultDynamicDims); - - initValue = rewriter.create(loc, initValue); - Value filledInitTensor = - rewriter.create(loc, initValue, emptyTensor) - .getResult(0); - auto createOp = [&](auto *typePtr) -> linalg::LinalgOp { - return cast( - rewriter - .create>( - loc, ArrayRef{resultType}, - ValueRange{input, fakeWindowDims.getResult()}, - filledInitTensor, strides, dilations, - linalg::getPrunedAttributeList(op)) - .getOperation()); - }; - linalg::LinalgOp poolingOp; - PoolingType poolingType = getPoolingType(op, result.getResultNumber()); - switch (poolingType) { - case PoolingType::k2DMin: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k3DMin: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k2DMax: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k3DMax: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k2DAdd: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k3DAdd: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::kInvalid: - return rewriter.notifyMatchFailure(op, "unknown reduction operation"); - } - poolingOps.push_back(poolingOp->getResult(0)); - } - rewriter.replaceOp(op, poolingOps); - return success(); - } -}; - -} // namespace - -namespace detail { -void populateStableHloReductionToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, bool enablePrimitiveOps) { - if (enablePrimitiveOps) { - patterns->add(typeConverter, context); - } else { - patterns->add(typeConverter, context); - } - patterns->add(typeConverter, - context); - - // Ensure specialized patterns are higher priority than their generic - // versions. - patterns->add(typeConverter, context, - PatternBenefit(2)); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp b/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp deleted file mode 100644 index 30f0bf28cf4f..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" - -#include - -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" - -namespace mlir::iree_compiler::stablehlo { - -namespace { - -Type convertInteger(IntegerType intType) { - return IntegerType::get(intType.getContext(), - intType.getIntOrFloatBitWidth()); -} - -Type convertShapedType(ShapedType shapedType) { - if (auto intType = llvm::dyn_cast(shapedType.getElementType())) - return shapedType.clone(convertInteger(intType)); - return shapedType; -} - -Value materializeCastFromIllegal(OpBuilder &builder, Type type, - ValueRange inputs, Location loc) { - Type fromType = getElementTypeOrSelf(inputs[0].getType()); - Type toType = getElementTypeOrSelf(type); - if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || - !toType.isSignlessInteger()) - return Value(); - // Use unrealized conversion casts to do signful->signless conversions. - return builder.create(loc, type, inputs[0]) - ->getResult(0); -} - -Value materializeCastToIllegal(OpBuilder &builder, Type type, ValueRange inputs, - Location loc) { - Type fromType = getElementTypeOrSelf(inputs[0].getType()); - Type toType = getElementTypeOrSelf(type); - if (!fromType.isSignlessInteger() || - (!toType.isSignedInteger() && !toType.isUnsignedInteger())) - return Value(); - // Use unrealized conversion casts to do signless->signful conversions. - return builder.create(loc, type, inputs[0]) - ->getResult(0); -} - -Value scalarToTensor(OpBuilder &builder, Type type, ValueRange inputs, - Location loc) { - assert(inputs.size() == 1); - if (llvm::isa(inputs.front().getType())) { - return Value(); - } - auto tensor = - builder - .create( - loc, RankedTensorType::get({}, inputs.front().getType()), - inputs.front()) - .getResult(); - return builder.create(loc, type, tensor) - .getResult(0); -} - -} // namespace - -RemoveSignTypeConverter::RemoveSignTypeConverter() { - addConversion([](Type type) { return type; }); - - addConversion(convertInteger); - addConversion(convertShapedType); - - addArgumentMaterialization(materializeCastToIllegal); - addSourceMaterialization(materializeCastToIllegal); - addTargetMaterialization(materializeCastFromIllegal); -} - -LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() { - addArgumentMaterialization(scalarToTensor); - addSourceMaterialization(scalarToTensor); - addTargetMaterialization(scalarToTensor); -} - -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.h b/compiler/plugins/input/StableHLO/Conversion/TypeConversion.h deleted file mode 100644 index d7b4ce03d0dc..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_TYPE_CONVERSION_H -#define IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_TYPE_CONVERSION_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir::iree_compiler::stablehlo { - -// Type converter to use as part of lowerings from dialects that carry signs -// in their types to those that are signless. -class RemoveSignTypeConverter : public TypeConverter { -public: - RemoveSignTypeConverter(); -}; - -// Type converter which adds additional materializations (beyond signless) -// that are needed as part of the HloToLinalg conversion patterns. -// This is the type converter used by the test pass and is the sanctioned -// way to use the underlying patterns. -class LinalgTypeConverter : public RemoveSignTypeConverter { -public: - LinalgTypeConverter(); -}; - -} // namespace mlir::iree_compiler::stablehlo - -#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_TYPE_CONVERSION_H diff --git a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel index 774e14f75eec..c8f2420d6ae4 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel @@ -25,14 +25,7 @@ iree_lit_test_suite( "legalize_shape_computations.mlir", "stablehlo_custom_calls.mlir", "stablehlo_to_iree_input_dialects.mlir", - "stablehlo_to_linalg_convolution.mlir", - "stablehlo_to_linalg_dot_prod.mlir", "stablehlo_to_linalg_ext.mlir", - "stablehlo_to_linalg_gather.mlir", - "stablehlo_to_linalg_pointwise.mlir", - "stablehlo_to_linalg_random.mlir", - "stablehlo_to_linalg_reduce.mlir", - "stablehlo_to_linalg.mlir", "verify_compiler_input_legality.mlir", "vhlo_stablehlo_mix_invalid.mlir", ], diff --git a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt index 6d9c9b1714d5..5f7202150f84 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt +++ b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt @@ -23,14 +23,7 @@ iree_lit_test_suite( "legalize_shape_computations.mlir" "stablehlo_custom_calls.mlir" "stablehlo_to_iree_input_dialects.mlir" - "stablehlo_to_linalg.mlir" - "stablehlo_to_linalg_convolution.mlir" - "stablehlo_to_linalg_dot_prod.mlir" "stablehlo_to_linalg_ext.mlir" - "stablehlo_to_linalg_gather.mlir" - "stablehlo_to_linalg_pointwise.mlir" - "stablehlo_to_linalg_random.mlir" - "stablehlo_to_linalg_reduce.mlir" "verify_compiler_input_legality.mlir" "vhlo_stablehlo_mix_invalid.mlir" TOOLS diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir deleted file mode 100644 index cd8a2b0f6de2..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir +++ /dev/null @@ -1,1618 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// RUN: iree-opt %s --iree-stablehlo-to-linalg="enable-primitive-ops=true" \ -// RUN: --split-input-file --canonicalize | \ -// RUN: FileCheck %s --check-prefix=CHECK-PRIMITIVE - -// CHECK-LABEL: func @bitcast_convert -func.func @bitcast_convert(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> - func.return %result : tensor<2x2xf32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.bitcast - -// ----- - -// CHECK-LABEL: func @bitcast_convert_dynamic -func.func @bitcast_convert_dynamic(%input: tensor) -> tensor { - %result = "stablehlo.bitcast_convert"(%input) : (tensor) -> tensor - func.return %result : tensor -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.bitcast - -// ----- - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @bitcast_convert_expand -func.func @bitcast_convert_expand(%input: tensor<6xi32>) -> tensor<6x4xi8> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<6xi32>) -> tensor<6x4xi8> - func.return %result : tensor<6x4xi8> -} - -// CHECK: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK: tensor.empty() : tensor<6x4xi8> -// CHECK: %[[RESULT:.*]] = linalg.generic { -// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]], -// CHECK: iterator_types = ["parallel", "parallel"]} -// CHECK: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]: i8): -// CHECK: %[[IOTA:.*]] = linalg.index 1 : index -// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32 -// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i32 -// CHECK: %[[SHIFT:.*]] = arith.shrui %[[IN]], %[[AMT]] : i32 -// CHECK: %[[TRUNC:.*]] = arith.trunci %[[SHIFT]] : i32 to i8 -// CHECK: linalg.yield %[[TRUNC]] : i8 -// CHECK: } -> tensor<6x4xi8> -// CHECK: return %[[RESULT]] : tensor<6x4xi8> - -// ----- - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: func @bitcast_convert_contract -func.func @bitcast_convert_contract(%input: tensor<7x4xi8>) -> tensor<7xi32> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<7x4xi8>) -> tensor<7xi32> - func.return %result : tensor<7xi32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<7xi32> -// CHECK: linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<7xi32>) -> tensor<7xi32> -// CHECK: %[[RESULT:.*]] = linalg.generic { -// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]], -// CHECK: iterator_types = ["parallel", "reduction"]} -// CHECK: ^bb0(%[[IN:.*]]: i8, %[[OUT:.*]]: i32): -// CHECK: %[[IOTA:.*]] = linalg.index 1 : index -// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32 -// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i3 -// CHECK: %[[EXT:.*]] = arith.extui %[[IN]] : i8 to i32 -// CHECK: %[[SHIFT:.*]] = arith.shli %[[EXT]], %[[AMT]] : i32 -// CHECK: %[[OR:.*]] = arith.ori %[[SHIFT]], %[[OUT]] : i32 -// CHECK: linalg.yield %[[OR]] : i32 -// CHECK: } -> tensor<7xi32> -// CHECK: return %[[RESULT]] : tensor<7xi32> - -// ----- - -// CHECK-LABEL: func @concatenate( -// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor -// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor -// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[C1]] : tensor -// CHECK: %[[VAL_15:.*]] = tensor.dim %[[VAL_2]], %[[C1]] : tensor -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index -// CHECK: %[[VAL_23:.*]] = tensor.empty(%[[VAL_5]], %[[VAL_17]]) : tensor -// CHECK: %[[VAL_24:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_23]] : tensor) { -// CHECK: ^bb0(%[[VAL_25:.*]]: i32): -// CHECK: %[[VAL_26:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_28:.*]] = linalg.index 1 : index -// CHECK: %[[VAL_30:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor -// CHECK: %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_30]] : index -// CHECK: %[[VAL_33:.*]] = scf.if %[[VAL_32]] -> (i32) { -// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_0]][%[[VAL_26]], %[[VAL_28]]] : tensor -// CHECK: scf.yield %[[VAL_35]] : i32 -// CHECK: } else { -// CHECK: %[[VAL_37:.*]] = tensor.dim %[[VAL_1]], %[[C1]] : tensor -// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_30]], %[[VAL_37]] : index -// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_38]] : index -// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (i32) { -// CHECK: %[[VAL_41:.*]] = arith.subi %[[VAL_28]], %[[VAL_30]] : index -// CHECK: %[[VAL_42:.*]] = tensor.extract %[[VAL_1]][%[[VAL_26]], %[[VAL_41]]] : tensor -// CHECK: scf.yield %[[VAL_42]] : i32 -// CHECK: } else { -// CHECK: %[[VAL_43:.*]] = arith.subi %[[VAL_28]], %[[VAL_38]] : index -// CHECK: %[[VAL_44:.*]] = tensor.extract %[[VAL_2]][%[[VAL_26]], %[[VAL_43]]] : tensor -// CHECK: scf.yield %[[VAL_44]] : i32 -// CHECK: } -// CHECK: scf.yield %[[VAL_45:.*]] : i32 -// CHECK: } -// CHECK: linalg.yield %[[VAL_46:.*]] : i32 -// CHECK: } -> tensor -// CHECK: return %[[VAL_47:.*]] : tensor -// CHECK: } -func.func @concatenate(%a: tensor, %b: tensor, %c: tensor) -> tensor { - %concat = "stablehlo.concatenate"(%a, %b, %c) { - dimension = 1 - } : (tensor, tensor, tensor) -> tensor - func.return %concat : tensor -} - -// ----- - -// CHECK-LABEL: func @concatenate_unsigned( -// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor to tensor -// CHECK-DAG: %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : tensor to tensor -// CHECK-DAG: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : tensor to tensor -// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_3]], %[[C0]] : tensor -// CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_3]], %[[C1]] : tensor -// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_4]], %[[C1]] : tensor -// CHECK: %[[VAL_18:.*]] = tensor.dim %[[VAL_5]], %[[C1]] : tensor -// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_10]], %[[VAL_14]] : index -// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index -// CHECK: %[[VAL_26:.*]] = tensor.empty(%[[VAL_8]], %[[VAL_20]]) : tensor -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_26]] : tensor) { -// CHECK: ^bb0(%[[VAL_28:.*]]: i32): -// CHECK: %[[VAL_29:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_30:.*]] = linalg.index 1 : index -// CHECK: %[[VAL_33:.*]] = tensor.dim %[[VAL_3]], %[[C1]] : tensor -// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_33]] : index -// CHECK: %[[VAL_36:.*]] = scf.if %[[VAL_35]] -> (i32) { -// CHECK: %[[VAL_38:.*]] = tensor.extract %[[VAL_3]][%[[VAL_29]], %[[VAL_30]]] : tensor -// CHECK: scf.yield %[[VAL_38]] : i32 -// CHECK: } else { -// CHECK: %[[VAL_40:.*]] = tensor.dim %[[VAL_4]], %[[C1]] : tensor -// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_33]], %[[VAL_40]] : index -// CHECK: %[[VAL_42:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_41]] : index -// CHECK: %[[VAL_43:.*]] = scf.if %[[VAL_42]] -> (i32) { -// CHECK: %[[VAL_44:.*]] = arith.subi %[[VAL_30]], %[[VAL_33]] : index -// CHECK: %[[VAL_45:.*]] = tensor.extract %[[VAL_4]][%[[VAL_29]], %[[VAL_44]]] : tensor -// CHECK: scf.yield %[[VAL_45]] : i32 -// CHECK: } else { -// CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_30]], %[[VAL_41]] : index -// CHECK: %[[VAL_47:.*]] = tensor.extract %[[VAL_5]][%[[VAL_29]], %[[VAL_46]]] : tensor -// CHECK: scf.yield %[[VAL_47]] : i32 -// CHECK: } -// CHECK: scf.yield %[[VAL_48:.*]] : i32 -// CHECK: } -// CHECK: linalg.yield %[[VAL_49:.*]] : i32 -// CHECK: } -> tensor -// CHECK: %[[VAL_50:.*]] = builtin.unrealized_conversion_cast %[[VAL_51:.*]] : tensor to tensor -// CHECK: return %[[VAL_50]] : tensor -// CHECK: } -func.func @concatenate_unsigned(%a: tensor, %b: tensor, %c: tensor) -> tensor { - %concat = "stablehlo.concatenate"(%a, %b, %c) { - dimension = 1 - } : (tensor, tensor, tensor) -> tensor - func.return %concat : tensor -} - -// ----- - -// CHECK-LABEL: func @constant -// CHECK: %[[CONSTANT:.*]] = arith.constant dense<10> : tensor -func.func @constant() -> tensor { - %result = "stablehlo.constant"() { - value = dense<10> : tensor - } : () -> (tensor) - func.return %result : tensor -} - -// ----- - -// CHECK-LABEL: func @elided_constant -// CHECK: %[[CONSTANT:.*]] = arith.constant dense_resource<__elided__> : tensor<1024xf32> -func.func @elided_constant() -> tensor<1024xf32> { - %result = "stablehlo.constant"() { - value = dense_resource<__elided__> : tensor<1024xf32> - } : () -> (tensor<1024xf32>) - func.return %result : tensor<1024xf32> -} - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: func @einsum_basic -// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x5x6xf32>) -func.func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "ijk,ikm->ijm", someattr}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> - func.return %0 : tensor<3x4x6xf32> -} -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x4x6xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5xf32>, tensor<3x5x6xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<3x4x6xf32>) -// CHECK-SAME: {someattr} -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @einsum_pointwisemul -func.func @einsum_pointwisemul(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abc,abc->abc"} : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> - func.return %0 : tensor<3x4x5xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x4x5xf32>) -// CHECK: tensor.empty() : tensor<3x4x5xf32> -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5xf32>, tensor<3x4x5xf32>) -// CHECK-SAME: outs(%[[DST:.+]] : tensor<3x4x5xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[RES:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: func @einsum_matmul -func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tensor<7x5xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "ae,ed->ad"}: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> - func.return %0 : tensor<7x5xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<7x9xf32>, %[[RHS:.*]]: tensor<9x5xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<7x5xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<7x9xf32>, tensor<9x5xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<7x5xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d5)> -// CHECK: func @einsum_broadcast4 -func.func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abcdh,hg->abcdg"}: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> - func.return %0 : tensor<3x4x5x6x8xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5x6x7xf32>, %[[RHS:.*]]: tensor<7x8xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x4x5x6x8xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<3x4x5x6x8xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: func @einsum_ellipsis -func.func @einsum_ellipsis(%arg0: tensor<1x512x128xf32>, %arg1: tensor<128x256xf32>) -> tensor<1x512x256xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "...x,xy->...y"} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32> - func.return %0 : tensor<1x512x256xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<1x512x128xf32>, %[[RHS:.*]]: tensor<128x256xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<1x512x256xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<1x512x128xf32>, tensor<128x256xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x512x256xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: func @einsum_dynamic_size_broadcast_dot -func.func @einsum_dynamic_size_broadcast_dot(%arg0: tensor, %arg1: tensor<4x?xf32>) -> tensor { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abc,cd->abd"} : (tensor, tensor<4x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-SAME: (%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor<4x?xf32>) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[LHS]], %[[C0]] : tensor -// CHECK: %[[DIM1:.+]] = tensor.dim %[[LHS]], %[[C1]] : tensor -// CHECK: %[[DIM2:.+]] = tensor.dim %[[RHS]], %[[C1:.+]] : tensor<4x?xf32> -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]]) : tensor -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor, tensor<4x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK: func @broadcast_in_dim -func.func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> - func.return %0 : tensor<7x10x6x4x5xf32> -} -// CHECK: tensor.empty() : tensor<7x10x6x4x5xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim -// CHECK-PRIMITIVE: tensor.collapse_shape -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE: permutation = [1, 0] -// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6x4x5xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [1, 2, 3] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK: func @broadcast_in_dim_ui32 -func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32> - func.return %0 : tensor<7x10x6x4x5xui32> -} -// CHECK: builtin.unrealized_conversion_cast %{{.*}} : tensor<5x7x1xui32> to tensor<5x7x1xi32> -// CHECK: tensor.empty() : tensor<7x10x6x4x5xi32> -// CHECK: %[[RES:.*]] = linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: i32, %{{.*}}: i32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : i32 -// CHECK: builtin.unrealized_conversion_cast %[[RES]] : tensor<7x10x6x4x5xi32> to tensor<7x10x6x4x5xui32> - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_ui32 -// CHECK-PRIMITIVE: tensor.collapse_shape -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE: permutation = [1, 0] -// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6x4x5xi32> -// CHECK-PRIMITIVE: %[[RES:.*]] = linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [1, 2, 3] -// CHECK-PRIMITIVE: builtin.unrealized_conversion_cast %[[RES]] : tensor<7x10x6x4x5xi32> to tensor<7x10x6x4x5xui32> - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @broadcast_in_dim_with_one_to_one -func.func @broadcast_in_dim_with_one_to_one( - %operand: tensor<1xf32>) -> tensor<1x5xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<1xf32>) -> tensor<1x5xf32> - func.return %0 : tensor<1x5xf32> -} -// CHECK: tensor.empty() : tensor<1x5xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_with_one_to_one -// CHECK-PRIMITIVE-NOT: tensor.collapse_shape -// CHECK-PRIMITIVE-NOT: linalg.transpose -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [1] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d1)> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @broadcast_in_dim_with_transpose -func.func @broadcast_in_dim_with_transpose( - %operand: tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> - func.return %0 : tensor<3x4x2x5xf32> -} -// CHECK: tensor.empty() : tensor<3x4x2x5xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_with_transpose -// CHECK-PRIMITIVE: tensor.empty() : tensor<3x4x2xf32> -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE: permutation = [1, 2, 0] -// CHECK-PRIMITIVE: tensor.empty() : tensor<3x4x2x5xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [3] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @broadcast_in_dim_scalar -func.func @broadcast_in_dim_scalar(%operand: tensor) -> tensor<7x10x6xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor) -> tensor<7x10x6xf32> - func.return %0 : tensor<7x10x6xf32> -} -// CHECK: tensor.empty() : tensor<7x10x6xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_scalar -// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [0, 1, 2] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @broadcast_scalar -func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { - %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array} : (tensor) -> tensor<4x2x1xf32> - func.return %0: tensor<4x2x1xf32> -} -// CHECK: tensor.empty() : tensor<4x2x1xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_scalar -// CHECK-PRIMITIVE: tensor.empty() : tensor<4x2x1xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE-SAME: ins( -// CHECK-PRIMITIVE-SAME: outs( -// CHECK-PRIMITIVE-SAME: dimensions = [0, 1, 2] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -// CHECK: func @broadcast -func.func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { - %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> - func.return %0: tensor<4x2x1x4x?x16xf32> -} -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> -// CHECK: %{{.*}} = tensor.empty(%[[DIM]]) : tensor<4x2x1x4x?x16xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast -// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-PRIMITIVE: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> -// CHECK-PRIMITIVE: %{{.*}} = tensor.empty(%[[DIM]]) : tensor<4x2x1x4x?x16xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [0, 1, 2] - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_f32 -func.func @iota_f32() -> tensor<7x10xf32> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64, someattr} : () -> (tensor<7x10xf32>) - func.return %result : tensor<7x10xf32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-SAME: {someattr} -// CHECK-NEXT: ^bb0(%{{.*}}: f32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @iota_f32 -// CHECK-PRIMITIVE: %[[EMPTY:.*]] = tensor.empty() -// CHECK-PRIMITIVE: linalg.map outs(%[[EMPTY]] : tensor<7x10xf32> -// CHECK-PRIMITIVE-SAME: {someattr} -// CHECK-PRIMITIVE: %[[INDEX:.*]] = linalg.index 1 -// CHECK-PRIMITIVE-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i64 -// CHECK-PRIMITIVE-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i64 to f32 -// CHECK-PRIMITIVE-NEXT: linalg.yield %[[FLOAT_CAST]] - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_i32 -func.func @iota_i32() -> tensor<7x10xi32> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xi32>) - func.return %result : tensor<7x10xi32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: i32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32 - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_ui32 -func.func @iota_ui32() -> tensor<7x10xui32> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xui32>) - func.return %result : tensor<7x10xui32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: i32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32 -// CHECK: builtin.unrealized_conversion_cast %{{.*}} : tensor<7x10xi32> to tensor<7x10xui32> - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_complexf32 -func.func @iota_complexf32() -> tensor<7x10xcomplex> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xcomplex>) - func.return %result : tensor<7x10xcomplex> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: complex): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i32 to f32 -// CHECK-NEXT: %[[COMPLEX_CAST:.*]] = complex.create %[[FLOAT_CAST]], %[[ZERO]] : complex -// CHECK-NEXT: linalg.yield %[[COMPLEX_CAST]] : complex - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @dynamic_iota_f32 -// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> -func.func @dynamic_iota_f32(%shape: tensor<3xi32>) -> tensor { - %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor) - func.return %result : tensor -} -// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] -// CHECK: %[[I1:.*]] = arith.index_cast %[[V1]] : i32 to index -// CHECK: %[[V2:.*]] = tensor.extract %[[SHAPE]][%c1] -// CHECK: %[[I2:.*]] = arith.index_cast %[[V2]] : i32 to index -// CHECK: tensor.empty(%[[I1]], %[[I2]]) : tensor -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: f32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @dynamic_iota_ui32 -// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> -func.func @dynamic_iota_ui32(%shape: tensor<3xi32>) -> tensor { - %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor) - func.return %result : tensor -} -// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] -// CHECK: %[[I1:.*]] = arith.index_cast %[[V1]] : i32 to index -// CHECK: %[[V2:.*]] = tensor.extract %[[SHAPE]][%c1] -// CHECK: %[[I2:.*]] = arith.index_cast %[[V2]] : i32 to index -// CHECK: tensor.empty(%[[I1]], %[[I2]]) : tensor -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: i32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : i32 -// CHECK: builtin.unrealized_conversion_cast %{{.*}} : tensor to tensor - -// ----- - -// CHECK-LABEL: @map_mixed -// CHECK-PRIMITIVE-LABEL: @map_mixed -func.func @map_mixed(%arg0: tensor, - %arg1: tensor<4xf32>) -> tensor { - %0 = "stablehlo.map"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} - : (tensor, tensor<4xf32>) -> tensor - func.return %0 : tensor -} - -// CHECK: linalg.generic -// CHECK: %[[ADD:.+]] = arith.addf -// CHECK: linalg.yield %[[ADD]] : f32 - -// CHECK-PRIMITIVE: linalg.map { arith.addf } - -// ----- - -// CHECK-LABEL: @map_one_arg -// CHECK-PRIMITIVE-LABEL: @map_one_arg -func.func @map_one_arg(%arg0: tensor) -> tensor { - %0 = "stablehlo.map"(%arg0) ({ - ^bb0(%arg2: tensor): - %1 = stablehlo.add %arg2, %arg2 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} - : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK: linalg.generic -// CHECK: %[[ADD:.+]] = arith.addf -// CHECK: linalg.yield %[[ADD]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: %[[ADD:.+]] = arith.addf -// CHECK-PRIMITIVE: linalg.yield %[[ADD]] : f32 - -// ----- - -// CHECK-LABEL: @map_compare -// CHECK-SAME: %[[ARG0:.*]]: tensor>, -// CHECK-SAME: %[[ARG1:.*]]: tensor>) -// CHECK-PRIMITIVE-LABEL: @map_compare -// CHECK-PRIMITIVE-SAME: %[[ARG0:.*]]: tensor>, -// CHECK-PRIMITIVE-SAME: %[[ARG1:.*]]: tensor>) -func.func @map_compare(%arg0: tensor>, - %arg1: tensor>) -> tensor { - %0 = "stablehlo.map"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor>, %arg3: tensor>): - %1 = stablehlo.real %arg2 : (tensor>) -> tensor - %2 = stablehlo.real %arg3 : (tensor>) -> tensor - %3 = "stablehlo.compare"(%1, %2) - {comparison_direction = #stablehlo} - : (tensor, tensor) -> tensor - "stablehlo.return"(%3) : (tensor) -> () - }) {dimensions = array} - : (tensor>, tensor>) -> tensor - func.return %0 : tensor -} - -// CHECK: %[[INIT:.+]] = tensor.empty -// CHECK: %[[MAP:.+]] = linalg.generic -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK-SAME: outs(%[[INIT]] : tensor) { -// CHECK: ^bb0(%[[A:.+]]: complex, %[[B:.+]]: complex, %{{.+}}: i1): -// CHECK: %[[RE1:.+]] = complex.re %[[A]] : complex -// CHECK: %[[RE2:.+]] = complex.re %[[B]] : complex -// CHECK: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32 -// CHECK: linalg.yield %[[CMP]] : i1 -// CHECK: } -// CHECK: return %[[MAP]] : tensor - -// CHECK-PRIMITIVE: %[[INIT:.+]] = tensor.empty -// CHECK-PRIMITIVE: %[[MAP:.+]] = linalg.map -// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) -// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex, %[[B:.+]]: complex) { -// CHECK-PRIMITIVE: %[[RE1:.+]] = complex.re %[[A]] : complex -// CHECK-PRIMITIVE: %[[RE2:.+]] = complex.re %[[B]] : complex -// CHECK-PRIMITIVE: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32 -// CHECK-PRIMITIVE: linalg.yield %[[CMP]] : i1 -// CHECK-PRIMITIVE: } -// CHECK-PRIMITIVE: return %[[MAP]] : tensor - -// ----- - -func.func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { - %0 = arith.constant dense<0.0> : tensor - %1 = "stablehlo.pad"(%arg0, %0) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xf32>, tensor) -> tensor<18x12xf32> - func.return %1 : tensor<18x12xf32> -} -// CHECK-LABEL: func @pad_cst -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: tensor.pad %[[ARG0]] low[4, 5] high[2, 3] -// CHECK: tensor.yield %[[CST]] : f32 -// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32> - -// ----- - -func.func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf32> { - %0 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xf32>, tensor) -> tensor<18x12xf32> - func.return %0 : tensor<18x12xf32> -} -// CHECK-LABEL: func @pad_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: tensor.pad %[[ARG0]] low[4, 5] high[2, 3] -// CHECK: tensor.yield %[[PAD]] : f32 -// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32> - -// ----- - -func.func @pad_interior(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor<29x15xui32> { - %0 = arith.constant dense<0> : tensor - %1 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xui32>, tensor) -> tensor<29x15xui32> - func.return %1 : tensor<29x15xui32> -} -// CHECK-LABEL: func @pad_interior -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[CAST0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<12x4xui32> to tensor<12x4xi32> -// CHECK-DAG: %[[CAST1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : tensor to tensor -// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[CAST1]][] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<29x15xi32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PAD]] : i32) outs(%[[INIT]] : tensor<29x15xi32>) -> tensor<29x15xi32> -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[CAST0]] into %[[FILL]][4, 5] [12, 4] [2, 2] : tensor<12x4xi32> into tensor<29x15xi32> - -// ----- - -func.func @pad_interior_negative(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor<25x9xui32> { - %0 = arith.constant dense<0> : tensor - %1 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xui32>, tensor) -> tensor<25x9xui32> - func.return %1 : tensor<25x9xui32> -} -// CHECK-LABEL: func @pad_interior_negative -// CHECK: %[[PAD:.*]] = tensor.insert_slice %{{.+}} into %{{.+}}[4, 0] [12, 4] [2, 2] : tensor<12x4xi32> into tensor<29x10xi32> -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 1] [25, 9] [1, 1] : tensor<29x10xi32> to tensor<25x9xi32> - -// ----- - -// CHECK-LABEL: func @real_dynamic_slice -// CHECK-SAME: (%[[OPERAND:.*]]: tensor<256x?xf32>, %[[START_INDICES:.*]]: tensor<2xindex>, %[[LIMIT_INDICES:.*]]: tensor<2xindex>, %[[STRIDES:.*]]: tensor<2xindex>) -func.func @real_dynamic_slice(%input: tensor<256x?xf32>, %start_indices: tensor<2xindex>, %limit_indices: tensor<2xindex>, %strides: tensor<2xindex>) -> tensor<256x?xf32> { - %0 = "stablehlo.real_dynamic_slice"(%input, %start_indices, %limit_indices, %strides) : (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32> - func.return %0 : tensor<256x?xf32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 - -// Fetch start index, limit index and stride. -// CHECK-DAG: %[[START0:.*]] = tensor.extract %[[START_INDICES]][%[[C0]]] -// CHECK-DAG: %[[STRIDE0:.*]] = tensor.extract %[[STRIDES]][%[[C0]]] - -// Clamp starting index : 0 <= start <= ub -// CHECK-DAG: %[[MAX0:.*]] = arith.maxsi %[[START0]], %[[C0]] : index -// CHECK-DAG: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C0]] : index - -// CHECK-DAG: %[[START1:.*]] = tensor.extract %[[START_INDICES]][%[[C1]]] -// CHECK-DAG: %[[LIMIT1:.*]] = tensor.extract %[[LIMIT_INDICES]][%[[C1]]] -// CHECK-DAG: %[[STRIDE1:.*]] = tensor.extract %[[STRIDES]][%[[C1]]] - -// 2.2. Since 1-th dimension of result is unknown we compute result size at 1-th -// dimension as size[1] = (limit - start)/stride -// CHECK-DAG: %[[DELTA1:.*]] = arith.subi %[[LIMIT1]], %[[START1]] : index -// CHECK-DAG: %[[SIZE1:.*]] = arith.ceildivui %[[DELTA1]], %[[STRIDE1]] : index - -// 2.3. Compute upper bound for starting index = operand_dim[1] - size[1]. -// where, size[1] is computed at step 2.2 -// CHECK-DAG: %[[OPERAND_DIM1:.*]] = tensor.dim %[[OPERAND]], %[[C1]] : tensor<256x?xf32> -// CHECK-DAG: %[[UB:.*]] = arith.subi %[[OPERAND_DIM1]], %[[SIZE1]] : index - -// 2.4. Clamp starting index : 0 <= start <= ub -// where upper bound (ub) is computed at step 2.3 -// CHECK-DAG: %[[MAX1:.*]] = arith.maxsi %[[START1]], %[[C0]] : index -// CHECK-DAG: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[UB]] : index - -// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[OPERAND]][%[[MIN0]], %[[MIN1]]] [256, %[[SIZE1]]] [%[[STRIDE0]], %[[STRIDE1]]] : tensor<256x?xf32> to tensor<256x?xf32> -// CHECK: return %[[SLICE]] : tensor<256x?xf32> - -// ----- - -// Verify that legalization of real_dynamic_slice legalization with integer -// dims work & passes verification. -// CHECK-LABEL: real_dynamic_slice_with_int -func.func @real_dynamic_slice_with_int(%arg0: tensor<10xi32> , %arg1: tensor<1xi32> ) -> tensor { - %0 = stablehlo.constant dense<0> : tensor<1xi32> - %1 = stablehlo.constant dense<1> : tensor<1xi32> - %2 = stablehlo.constant dense<0> : tensor - %4 = "stablehlo.real_dynamic_slice"(%arg0, %0, %arg1, %1) : (tensor<10xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - func.return %4 : tensor -} - -// ----- - -// CHECK-LABEL: func @reshape_0D_1D -func.func @reshape_0D_1D(%arg0: tensor) -> tensor<1xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> - func.return %0 : tensor<1xi32> -} -// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1] : tensor into tensor<1xi32> - -// ----- - -// CHECK-LABEL: func @reshape_0D_1D_unsigned -// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]] -func.func @reshape_0D_1D_unsigned(%arg0: tensor) -> tensor<1xui32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<1xui32> - func.return %0 : tensor<1xui32> -} -// CHECK: %[[ARG_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor to tensor -// CHECK: %[[RET_SIGNLESS:.*]] = tensor.expand_shape %[[ARG_SIGNLESS]] [] output_shape [1] : tensor into tensor<1xi32> -// CHECK: %[[RET_UNSIGNED:.*]] = builtin.unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32> -// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32> - -// ----- - -// CHECK-LABEL: func @reshape_1D_0D -func.func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor - -// ----- - -// CHECK-LABEL: func @reshape_1D_0D_unsigned -// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]] -func.func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor - func.return %0 : tensor -} -// CHECK: %[[ARG_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32> -// CHECK: %[[RET_SIGNLESS:.*]] = tensor.collapse_shape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor -// CHECK: %[[RET_UNSIGNED:.*]] = builtin.unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor to tensor -// CHECK: return %[[RET_UNSIGNED]] : tensor - -// ----- - -// CHECK-LABEL: func @reshape_3D_2D -func.func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> - func.return %0 : tensor<12x42xi32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2]] - -// ----- - -// CHECK-LABEL: func @reshape_4D_2D -func.func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> - func.return %0 : tensor<12x42xi32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]] - -// ----- - -// CHECK-LABEL: func @reshape_2D_4D -func.func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> - func.return %0 : tensor<12x1x42x1xi32> -} -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2, 3]] - -// ----- - -// CHECK-LABEL: func @reshape_3D_4D -func.func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> - func.return %0 : tensor<1x784x1x1xf32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2]] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]] - -// ----- - -// CHECK-LABEL: func @reshape_4D_3D -func.func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> - func.return %0 : tensor<1x240x1xf32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2] - -// ----- - -// CHECK-LABEL: func @reshape1_4D_4D -func.func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> - func.return %0 : tensor<1x4x1x512xi32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3] - -// ----- - -// CHECK-LABEL: func @reshape2_4D_4D -func.func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> - func.return %0 : tensor<4x1024x1x1xi32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3] - -// ----- - -// CHECK-LABEL: func @reshape_dynamic_in -func.func @reshape_dynamic_in(%arg0: tensor) -> tensor<2x4x5xf32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<2x4x5xf32> - func.return %0 : tensor<2x4x5xf32> -} -// CHECK: %[[FLATTEN:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] : tensor into tensor -// CHECK: %[[CAST:.*]] = tensor.cast %[[FLATTEN]] : tensor to tensor<40xf32> -// CHECK: tensor.expand_shape %[[CAST]] {{\[}}[0, 1, 2]] output_shape [2, 4, 5] : tensor<40xf32> into tensor<2x4x5xf32> - -// ----- - -// CHECK-LABEL: func @reshape_1D_2D_dynamic -func.func @reshape_1D_2D_dynamic(%arg0: tensor) -> tensor<1x3xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<1x3xi32> - func.return %0 : tensor<1x3xi32> -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor to tensor<3xi32> -// CHECK: tensor.expand_shape %[[CAST]] {{\[}}[0, 1]] output_shape [1, 3] : tensor<3xi32> into tensor<1x3xi32> - -// ----- - -// CHECK-LABEL: func @reshape_2D_1D_dynamic -func.func @reshape_2D_1D_dynamic(%arg0: tensor) -> tensor<3xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} -// CHECK: %[[FLATTEN:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] : tensor into tensor -// CHECK: %[[CAST:.*]] = tensor.cast %[[FLATTEN]] : tensor to tensor<3xi32> -// CHECK: return %[[CAST:.*]] : tensor<3xi32> - -// ----- -// CHECK-LABEL: func @reshape_2D_1D_semidynamic -func.func @reshape_2D_1D_semidynamic(%arg0: tensor<1x?xi32>) -> tensor<1xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1x?xi32>) -> tensor<1xi32> - func.return %0 : tensor<1xi32> -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xi32> to tensor<1x1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}[0, 1]] : tensor<1x1xi32> into tensor<1xi32> -// CHECK: return %[[COLLAPSE:.*]] : tensor<1xi32> - -// ----- - -// CHECK-LABEL: func @reshape_1D_0D_dynamic -func.func @reshape_1D_0D_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor to tensor<1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}] : tensor<1xi32> into tensor -// CHECK: return %[[COLLAPSE:.*]] : tensor - -// ----- - -// CHECK-LABEL: func @reshape_2D_0D_dynamic -func.func @reshape_2D_0D_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor to tensor<1x1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}] : tensor<1x1xi32> into tensor -// CHECK: return %[[COLLAPSE:.*]] : tensor - -// ----- - -// CHECK-LABEL: func @reshape_3D_1D_semidynamic -func.func @reshape_3D_1D_semidynamic(%arg0: tensor<16x1x?xi32>) -> tensor<16xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<16x1x?xi32>) -> tensor<16xi32> - func.return %0 : tensor<16xi32> -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<16x1x?xi32> to tensor<16x1x1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}[0, 1, 2]] : tensor<16x1x1xi32> into tensor<16xi32> -// CHECK: return %[[COLLAPSE:.*]] : tensor<16xi32> - -// ----- - -// CHECK-LABEL: func @reshape_empty -func.func @reshape_empty(%arg0: tensor<7x0xf64>) -> tensor<0x42x101xf64> { - %0 = stablehlo.reshape %arg0 : (tensor<7x0xf64>) -> tensor<0x42x101xf64> - return %0 : tensor<0x42x101xf64> -} - -// CHECK: %[[INIT:.*]] = tensor.empty -// CHECK: return %[[INIT]] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @reverse -func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { - %result = "stablehlo.reverse"(%input) { - dimensions = array, someattr - } : (tensor<2x3xf32>) -> tensor<2x3xf32> - func.return %result : tensor<2x3xf32> -} -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-SAME: {someattr} - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -// CHECK: func.func @select_and_scatter -func.func @select_and_scatter(%arg0 : tensor<2x8x8x1xf32>, %arg1 : tensor<2x4x4x1xf32>, %arg2 : tensor) -> tensor<2x8x8x1xf32> { - %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - %9 = stablehlo.compare GE, %arg3, %arg4, FLOAT : (tensor, tensor) -> tensor - stablehlo.return %9 : tensor - }, { - ^bb0(%arg3: tensor, %arg4: tensor): - %9 = stablehlo.add %arg3, %arg4 : tensor - stablehlo.return %9 : tensor - }) { - padding = dense<0> : tensor<4x2xi64>, - window_dimensions = array, - window_strides = array - } : (tensor<2x8x8x1xf32>, tensor<2x4x4x1xf32>, tensor) -> tensor<2x8x8x1xf32> - - return %0 : tensor<2x8x8x1xf32> -} -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[CN1:.+]] = arith.constant -1 : i32 -// CHECK: %[[EMPTY_VAL:.+]] = tensor.empty() : tensor<2x4x4x1xf32> -// CHECK: %[[EMPTY_IDX:.+]] = tensor.empty() : tensor<2x4x4x1xi32> -// CHECK: %[[FILL_IDX:.+]] = linalg.fill ins(%[[CN1]] : i32) outs(%[[EMPTY_IDX]] : tensor<2x4x4x1xi32>) -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<2x2xf32> -// CHECK: %[[SELECT_GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map1, #map2, #map2] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} -// CHECK-SAME: ins(%arg0, %[[WINDOW]] : tensor<2x8x8x1xf32>, tensor<2x2xf32>) -// CHECK-SAME: outs(%[[EMPTY_VAL]], %[[FILL_IDX]] : tensor<2x4x4x1xf32>, tensor<2x4x4x1xi32>) -// CHECK: ^bb0(%[[VAL:.+]]: f32, %[[IDX:.+]]: f32, %[[OLD_VAL:.+]]: f32, %[[OLD_IDX:.+]]: i32): -// CHECK: %[[CMPF:.+]] = arith.cmpf oge, %[[VAL]], %[[OLD_VAL]] : f32 -// CHECK: %[[CMPI:.+]] = arith.cmpi eq, %[[OLD_IDX]], %[[CN1]] : i32 -// CHECK: %[[PRED:.+]] = arith.ori %[[CMPF]], %[[CMPI]] : i1 -// CHECK: %[[IDX4:.+]] = linalg.index 4 : index -// CHECK: %[[IDX5:.+]] = linalg.index 5 : index -// CHECK: %[[MUL:.+]] = arith.muli %[[IDX4]], %[[C2]] : index -// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[IDX5]] : index -// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]] : index to i32 -// CHECK: %[[SEL_IDX:.+]] = arith.select %[[PRED]], %[[CAST]], %[[OLD_IDX]] : i32 -// CHECK: %[[SEL_VAL:.+]] = arith.select %[[PRED]], %[[VAL]], %[[OLD_VAL]] : f32 -// CHECK: linalg.yield %[[SEL_VAL]], %[[SEL_IDX]] : f32, i32 - -// CHECK: %[[SCATTER_EMPTY:.+]] = tensor.empty() : tensor<2x4x2x4x2x1xf32> -// CHECK: %[[INIT:.+]] = tensor.extract %arg2[] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT]] : f32) outs(%[[SCATTER_EMPTY]] : tensor<2x4x2x4x2x1xf32>) -> tensor<2x4x2x4x2x1xf32> -// CHECK: %[[SCATTER:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#map3, #map3, #map4] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} -// CHECK-SAME: ins(%[[SELECT_GENERIC]]#1, %arg1 : tensor<2x4x4x1xi32>, tensor<2x4x4x1xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x4x2x4x2x1xf32>) -// CHECK: ^bb0(%[[IDX:.+]]: i32, %[[UPDATE:.+]]: f32, %[[OLD:.+]]: f32): -// CHECK: %[[NEW:.+]] = arith.addf %[[UPDATE]], %[[OLD]] : f32 -// CHECK: %[[IDX2:.+]] = linalg.index 2 : index -// CHECK: %[[IDX4:.+]] = linalg.index 4 : index -// CHECK: %[[MUL:.+]] = arith.muli %[[IDX2]], %[[C2]] : index -// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[IDX4]] : index -// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]] : index to i32 -// CHECK: %[[CMP:.+]] = arith.cmpi eq, %[[CAST]], %[[IDX]] : i32 -// CHECK: %[[SELECT:.+]] = arith.select %[[CMP]], %[[NEW]], %[[OLD]] : f32 -// CHECK: linalg.yield %[[SELECT]] - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]] -// CHECK-NEXT{literal}: [[0], [1, 2], [3, 4], [5]] -// CHECK: return %[[COLLAPSE]] : tensor<2x8x8x1xf32> - -// ----- - -// CHECK-LABEL: set_dimension_size -// CHECK-SAME: %[[VALUE:.*]]: tensor<2x?xf32, #stablehlo.bounds -func.func @set_dimension_size( - %value: tensor<2x?xf32, #stablehlo.type_extensions>, - %dimension: tensor) - -> tensor<2x?xf32, #stablehlo.type_extensions> { - // CHECK: tensor.extract_slice %[[VALUE]][0, 0] [2, %{{.*}}] [1, 1] : tensor<2x?xf32, #stablehlo.bounds> to tensor<2x?xf32, #stablehlo.bounds> - %0 = "stablehlo.set_dimension_size"(%value, %dimension) { dimension = 1 } - : (tensor<2x?xf32, #stablehlo.type_extensions>, tensor) - -> tensor<2x?xf32, #stablehlo.type_extensions> - func.return %0 : tensor<2x?xf32, #stablehlo.type_extensions> -} - -// ----- - -func.func @torch_index_select(%arg0: tensor<5x1x5xi32>, - %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - dim = 0 : i64, - batch_dims = 0 : i64, - someattr - } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> - func.return %0 : tensor<2x1x5xi32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @torch_index_select -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -// CHECK: %[[INIT1:.+]] = tensor.empty() : -// CHECK: %[[INIT2:.+]] = tensor.empty() : -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT1]] : -// CHECK-SAME: outs(%[[INIT2]] : -// CHECK-SAME: {someattr} -// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32, %{{.+}}: i32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[J:.+]] = linalg.index 1 -// CHECK: %[[K:.+]] = linalg.index 2 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> -// CHECK: linalg.yield %[[VAL2]] : i32 - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @torch_index_select_unsigned -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>, - %arg1: tensor<2xi32>) -> tensor<2x1x5xui32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - dim = 0 : i64, - batch_dims = 0 : i64 - } : (tensor<5x1x5xui32>, tensor<2xi32>) -> tensor<2x1x5xui32> - func.return %0 : tensor<2x1x5xui32> -} -// CHECK: %[[INPUT_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32> -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x5xi32> -// CHECK: %[[RES:.+]] = linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT]] : tensor<2xi32>, tensor<1x5xi32>) -// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32, %{{.+}}: i32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[J:.+]] = linalg.index 1 -// CHECK: %[[K:.+]] = linalg.index 2 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> -// CHECK: linalg.yield %[[VAL2]] : i32 -// CHECK: %[[RES_UNSIGNED:.+]] = builtin.unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32> -// CHECK: return %[[RES_UNSIGNED]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> -// CHECK: func @torch_index_select_scalar -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_scalar(%arg0: tensor<4x8xf32>, - %arg1: tensor) -> tensor<8xf32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - batch_dims = 0 : i64, - dim = 0 : i64 - } : (tensor<4x8xf32>, tensor) -> tensor<8xf32> - func.return %0 : tensor<8xf32> -} -// CHECK: %[[T0:.+]] = tensor.empty() : tensor<8xf32> -// CHECK: %[[T1:.+]] = tensor.empty() : tensor<8xf32> -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP1]] -// CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[T0]] : tensor, tensor<8xf32>) outs(%[[T1]] : tensor<8xf32>) -// CHECK: ^{{.+}}(%[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[I:.+]] = linalg.index 0 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32> -// CHECK: linalg.yield %[[VAL2]] : f32 - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @torch_index_select_batch -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>, - %arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - dim = 2 : i64, - batch_dims = 1 : i64 - } : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32> - func.return %0 : tensor<4x7x1x2xf32> -} -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7x2xf32> -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT]] : -// CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32, %{{.+}}: f32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[I:.+]] = linalg.index 0 -// CHECK: %[[J:.+]] = linalg.index 1 -// CHECK: %[[L:.+]] = linalg.index 3 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32> -// CHECK: linalg.yield %[[VAL2]] : f32 - -// ----- - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @torch_index_select_dynamic -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_dynamic(%input: tensor, - %index: tensor) -> tensor{ - %0 = "stablehlo.torch_index_select"(%input, %index) { - batch_dims = 1 : i64, - dim = 2 : i64 - } : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] -// CHECK: %[[D2:.+]] = tensor.dim %[[INDEX]], %[[C1]] -// CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[D4:.+]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK: %[[D5:.+]] = tensor.dim %[[INPUT]], %[[C1]] -// CHECK: %[[D6:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[INIT0:.+]] = tensor.empty(%[[D4]], %[[D5]], %[[D6]]) : tensor -// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT0]] : tensor, tensor) -// CHECK-SAME: outs(%[[INIT1]] : tensor) -// CHECK: ^{{.+}}( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32, %{{[a-zA-Z0-9_]+}}: f32) -// CHECK: %[[POS:.+]] = arith.index_cast %[[ARG0]] -// CHECK: %[[IDX0:.+]] = linalg.index 0 -// CHECK: %[[IDX1:.+]] = linalg.index 1 -// CHECK: %[[IDX3:.+]] = linalg.index 3 -// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[IDX0]], %[[IDX1]], %[[POS]], %[[IDX3]]] -// CHECK: linalg.yield %[[YIELD]] - -// ----- - -// CHECK-LABEL: func @slice_whole_stride -// CHECK: tensor.extract_slice %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32> -func.func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { - %0 = "stablehlo.slice"(%arg0) { - start_indices = array, - limit_indices = array, - strides = array - } : (tensor<3x4xi32>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: func @slice_stride_part -// CHECK: tensor.extract_slice %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32> -func.func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { - %0 = "stablehlo.slice"(%arg0) { - start_indices = array, - limit_indices = array, - strides = array - } : (tensor<3x4xi32>) -> tensor<1x2xi32> - func.return %0 : tensor<1x2xi32> -} - -// ----- - -// CHECK-LABEL: func @slice_with_strides -// CHECK: tensor.extract_slice %{{.*}}[0] [6] [2] : tensor<13xi32> to tensor<6xi32> -func.func @slice_with_strides(%arg0: tensor<13xi32>) -> tensor<6xi32> { - %0 = "stablehlo.slice"(%arg0) { - limit_indices = array, - start_indices = array, - strides = array - } : (tensor<13xi32>) -> tensor<6xi32> - func.return %0 : tensor<6xi32> -} - -// ----- - -// CHECK-LABEL: func @slice_with_strides -// CHECK: tensor.extract_slice %{{.*}}[0] [3] [2] : tensor<6xi32> to tensor<3xi32> -func.func @slice_with_strides2(%arg0: tensor<6xi32>) -> tensor<3xi32> { - %0 = "stablehlo.slice"(%arg0) { - limit_indices = array, - start_indices = array, - strides = array - } : (tensor<6xi32>) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} - -// ----- - -// CHECK-LABEL: func @slice_with_empty_result -// CHECK: tensor.extract_slice %{{.*}}[0, 2, 0] [3, 0, 5] [1, 2, 1] : tensor<3x3x5xf64> to tensor<3x0x5xf64> -func.func @slice_with_empty_result(%arg0: tensor<3x3x5xf64>) -> tensor<3x0x5xf64> { - %0 = "stablehlo.slice"(%arg0) { - limit_indices = array, - start_indices = array, - strides = array - } : (tensor<3x3x5xf64>) -> tensor<3x0x5xf64> - func.return %0 : tensor<3x0x5xf64> -} - -// ----- - -// CHECK-LABEL: func @dynamic_slice( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor) -> tensor<1x4xf32> { - %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = array - } : (tensor<3x4xf32>, tensor, tensor) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C2]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C0]] : index -// CHECK: tensor.extract_slice %[[ARG0]][%[[CLAMPED1]], %[[CLAMPED2]]] [1, 4] [1, 1] - -// ----- - -// CHECK-LABEL: func @dynamic_slice_unsigned_index( -func.func @dynamic_slice_unsigned_index( - %arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) - -> tensor<1x4xui32> { - %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = array - } : (tensor<3x4xui32>, tensor, tensor) -> tensor<1x4xui32> - func.return %0 : tensor<1x4xui32> -} - -// CHECK: %[[EXTRACT1:.*]] = tensor.extract -// CHECK: arith.index_castui %[[EXTRACT1]] - -// ----- - -// CHECK-LABEL: func @dynamic_slice_unsigned( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) -> tensor<1x4xui32> { - %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = array - } : (tensor<3x4xui32>, tensor, tensor) -> tensor<1x4xui32> - func.return %0 : tensor<1x4xui32> -} - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[SIGNLESS_ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x4xui32> to tensor<3x4xi32> -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C2]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C0]] : index -// CHECK: tensor.extract_slice %[[SIGNLESS_ARG0]][%[[CLAMPED1]], %[[CLAMPED2]]] [1, 4] [1, 1] - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_update_slice(%target: tensor<3x3xi32>, %update: tensor<2x2xi32>, %c0: tensor) -> tensor<3x3xi32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xi32>, tensor<2x2xi32>, tensor, tensor) -> tensor<3x3xi32> - func.return %0 : tensor<3x3xi32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C1]] : index -// CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG1]] into %[[ARG0]] -// CHECK-SAME: [%[[CLAMPED1]], %[[CLAMPED2]]] [2, 2] [1, 1] -// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32> -// CHECK: return %[[RES]] : tensor<3x3xi32> - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice_unsigned_index( -func.func @dynamic_update_slice_unsigned_index( - %target: tensor<3x3xi32>, %update: tensor<2x2xi32>, - %idx: tensor) -> tensor<3x3xi32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %idx, %idx) - : (tensor<3x3xi32>, tensor<2x2xi32>, tensor, tensor) -> tensor<3x3xi32> - func.return %0 : tensor<3x3xi32> -} - -// CHECK: %[[EXTRACT1:.*]] = tensor.extract -// CHECK: arith.index_castui %[[EXTRACT1]] - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice_unsigned( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2x2xui32>, %c0: tensor) -> tensor<3x3xui32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xui32>, tensor<2x2xui32>, tensor, tensor) -> tensor<3x3xui32> - func.return %0 : tensor<3x3xui32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[SIGNLESS_UPDATE:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : tensor<2x2xui32> to tensor<2x2xi32> -// CHECK-DAG: %[[SIGNLESS_TARGET:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x3xui32> to tensor<3x3xi32> -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C1]] : index -// CHECK: %[[SIGNLESS_RES:.*]] = tensor.insert_slice %[[SIGNLESS_UPDATE]] into %[[SIGNLESS_TARGET]] -// CHECK-SAME: [%[[CLAMPED1]], %[[CLAMPED2]]] [2, 2] [1, 1] -// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32> -// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[SIGNLESS_RES]] : tensor<3x3xi32> to tensor<3x3xui32> -// CHECK: return %[[RES]] : tensor<3x3xui32> - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice_float( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_update_slice_float(%target: tensor<3x3xf32>, - %update: tensor<2x2xf32>, - %c0: tensor) -> tensor<3x3xf32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xf32>, tensor<2x2xf32>, tensor, tensor) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C1]] : index -// CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG1]] into %[[ARG0]] -// CHECK-SAME: [%[[CLAMPED1]], %[[CLAMPED2]]] [2, 2] [1, 1] -// CHECK-SAME: : tensor<2x2xf32> into tensor<3x3xf32> -// CHECK: return %[[RES]] : tensor<3x3xf32> - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @transpose -func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - %0 = "stablehlo.transpose"(%arg0) {permutation = array} - : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> - func.return %0 : tensor<3x2x5x9xi32> -} -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] - -// CHECK-PRIMITIVE-LABEL: func @transpose -// CHECK-PRIMITIVE: linalg.transpose - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @transpose_dynamic -func.func @transpose_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.transpose"(%arg0) {permutation = array, someattr} - : (tensor) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] -// CHECK: %[[D1:.*]] = tensor.dim %arg0, %[[C1]] -// CHECK: %[[D3:.*]] = tensor.dim %arg0, %[[C3]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D0]], %[[D3]]) : tensor -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-SAME: ins(%arg0 : tensor) outs(%[[INIT]] : tensor) -// CHECK-SAME: {someattr} - -// CHECK-PRIMITIVE-LABEL: func @transpose_dynamic -// CHECK-PRIMITIVE-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-PRIMITIVE-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-PRIMITIVE: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] -// CHECK-PRIMITIVE: %[[D1:.*]] = tensor.dim %arg0, %[[C1]] -// CHECK-PRIMITIVE: %[[D3:.*]] = tensor.dim %arg0, %[[C3]] -// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D0]], %[[D3]]) : tensor -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE-SAME: ins(%arg0 : tensor) -// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) -// CHECK-PRIMITIVE-SAME: permutation = [1, 0, 3, 2] -// CHECK-PRIMITIVE-SAME: {someattr} - -// ----- - -func.func @transpose_unsigned(%arg0: tensor<2x2xui32>) -> tensor<2x2xui32> { - %0 = "stablehlo.transpose"(%arg0) { - permutation = array, - result_layout = dense<[0, 1]> : tensor<2xindex> - } : (tensor<2x2xui32>) -> tensor<2x2xui32> - return %0 : tensor<2x2xui32> -} - -// Regression test. Just check that unsigned ints lower successfully. -// CHECK-LABEL: func @transpose_unsigned -// CHECK-PRIMITIVE-LABEL: func @transpose_unsigned diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir deleted file mode 100644 index 6a4702e83031..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir +++ /dev/null @@ -1,595 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// CHECK-LABEL: @linalg.conv_0d_nc -func.func @linalg.conv_0d_nc(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, f]x[i, o]->[b, f], - window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} - { - batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() -// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%cst{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<3x2xf32>, tensor<2x3xf32>) outs(%[[FILL]] : tensor<3x3xf32>) - -// ----- - -// CHECK-LABEL: func @linalg.conv_1d_nwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -func.func @linalg.conv_1d_nwc(%arg0: tensor, %arg1: tensor<2x?x?xf32>) - -> tensor { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0]]> : tensor<1x2xi64>, - rhs_dilation = array, - window_strides = array, - someattr - } : (tensor, tensor<2x?x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32> -// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]], %[[DIM2]]) -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.conv_1d_nwc_wcf -// CHECK-SAME: {dilations = dense<1> : tensor<1xi64> -// CHECK-SAME: someattr -// CHECK-SAME: strides = dense<1> : tensor<1xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<2x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @conv_2d_nhwc_hwcf -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_2d_nhwc_hwcf(%arg0: tensor, %arg1: tensor<3x2x?x?xf32>) - -> tensor { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array - } : (tensor, tensor<3x2x?x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32> -// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]], %[[DIM3]]) -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.conv_2d_nhwc -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3x2x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @conv_transpose_2d -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_transpose_2d(%arg0: tensor<2x9x10x3xf32>, - %arg1: tensor<4x4x3x3xf32>) - -> tensor<2x15x25x3xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[6, 6], [6, 6]], - lhs_dilate = [1, 2], rhs_dilate = [2, 2]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, - #stablehlo] - } : (tensor<2x9x10x3xf32>, tensor<4x4x3x3xf32>) -> tensor<2x15x25x3xf32> - return %0 : tensor<2x15x25x3xf32> -} -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: %[[LHS_INIT:.+]] = tensor.empty() -// CHECK: %[[LHS_FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[LHS_INIT]] -// CHECK: %[[LHS_PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[LHS_FILL]][0, 6, 6, 0] [2, 9, 10, 3] [1, 1, 2, 1] : tensor<2x9x10x3xf32> into tensor<2x21x31x3xf32> -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: {dilations = dense<2> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[LHS_PAD]], %[[ARG1]] : tensor<2x21x31x3xf32>, tensor<4x4x3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x15x25x3xf32>) -> tensor<2x15x25x3xf32> - -// ----- - -// CHECK-LABEL: func @conv_transpose_complex_2d -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_transpose_complex_2d(%arg0: tensor<2x9x10x3xcomplex>, - %arg1: tensor<4x4x3x3xcomplex>) - -> tensor<2x15x25x3xcomplex> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[6, 6], [6, 6]], - lhs_dilate = [1, 2], rhs_dilate = [2, 2]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, - #stablehlo] - } : (tensor<2x9x10x3xcomplex>, tensor<4x4x3x3xcomplex>) -> tensor<2x15x25x3xcomplex> - return %0 : tensor<2x15x25x3xcomplex> -} -// CHECK: %[[ZERO:.+]] = complex.constant [0.000000e+00 : f32, 0.000000e+00 : f32] : complex -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: %[[LHS_INIT:.+]] = tensor.empty() -// CHECK: %[[LHS_FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[LHS_INIT]] -// CHECK: %[[LHS_PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[LHS_FILL]][0, 6, 6, 0] [2, 9, 10, 3] [1, 1, 2, 1] : tensor<2x9x10x3xcomplex> into tensor<2x21x31x3xcomplex> -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: {dilations = dense<2> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[LHS_PAD]], %[[ARG1]] : tensor<2x21x31x3xcomplex>, tensor<4x4x3x3xcomplex>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x15x25x3xcomplex>) -> tensor<2x15x25x3xcomplex> - -// ----- - -// Just check that this lowers successfully. -// CHECK-LABEL: func @conv_different_batch_dim_in_out -func.func @conv_different_batch_dim_in_out(%arg0: tensor<1x1x1xf64>, - %arg1: tensor<1x1x1xf64>) - -> tensor<1x1x1xf64> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [f, 0, b]x[i, o, 0]->[f, b, 0], - window = {stride = [1], pad = [[0, 0]], lhs_dilate = [1], - rhs_dilate = [1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x1x1xf64>, tensor<1x1x1xf64>) -> tensor<1x1x1xf64> - return %0 : tensor<1x1x1xf64> -} - -// ----- - -// Just check that this lowers successfully. -// CHECK-LABEL: func @conv_different_batch_dim_in_out_with_feature_group_count -func.func @conv_different_batch_dim_in_out_with_feature_group_count( - %arg0: tensor<4x6x7x1xf64>, %arg1: tensor<2x6x3x2xf64>) - -> tensor<1x2x1x2xf64> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f], - window = {stride = [1, 1], pad = [[0, 0], [0, -1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 2], - reverse = [0, 0]} - { - batch_group_count = 1 : i64, - feature_group_count = 2 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<4x6x7x1xf64>, tensor<2x6x3x2xf64>) -> tensor<1x2x1x2xf64> - return %0 : tensor<1x2x1x2xf64> -} - -// ----- - -// CHECK-LABEL: func @conv_3d_ndhwc_dhwcf -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor, %arg1: tensor<2x2x2x?x?xf32>) - -> tensor { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>, - rhs_dilation = array, - window_strides = array - } : (tensor, tensor<2x2x2x?x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[DIM4:.+]] = tensor.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32> -// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]], %[[DIM4]]) -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.conv_3d_ndhwc_dhwcf -// CHECK-SAME: {dilations = dense<1> : tensor<3xi64> -// CHECK-SAME: strides = dense<1> : tensor<3xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<2x2x2x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>) - -> tensor<1x2x4x3xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array - } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32> - func.return %0 : tensor<1x2x4x3xf32> -} -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x4x3xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[INIT]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32> -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32> - -// ----- - -// CHECK-LABEL: func @linalg.conv_2D_padding_test1 -// CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>) -func.func @linalg.conv_2D_padding_test1(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>) - -> tensor<400x1024x1024x1xf16> { - %0 = stablehlo.convolution(%arg1, %arg0) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], pad = [[0, 0], [16, 16]], rhs_dilate = [1, 1] } - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64 - } : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1024x1024x1xf16>) - func.return %0 : tensor<400x1024x1024x1xf16> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16 -// CHECK-NEXT: %[[INIT:.*]] = tensor.empty() : tensor<400x1024x1024x1xf16> -// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16> -// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 0, 16, 0] high[0, 0, 16, 0] { -// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK-NEXT: tensor.yield %[[ZERO]] : f16 -// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1024x1056x1xf16> -// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %[[FILTER]] : tensor<400x1024x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%[[FILL]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16> -// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16> - -// ----- - -// CHECK-LABEL: func @linalg.conv_2D_padding_test2 -// CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>) -func.func @linalg.conv_2D_padding_test2(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>) - -> tensor<400x1040x1024x1xf16> { - %0 = stablehlo.convolution(%arg1, %arg0) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64 - } : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1040x1024x1xf16>) - return %0 : tensor<400x1040x1024x1xf16> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16 -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<400x1040x1024x1xf16> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16> -// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0] { -// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK-NEXT: tensor.yield %[[ZERO]] : f16 -// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1040x1056x1xf16> -// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%[[FILL]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16> -// CHECK-NEXT: return %[[RESULT]] : tensor<400x1040x1024x1xf16> - -// ----- - -// CHECK-LABEL: func @depthwise_conv -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>, - %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 2 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array, - someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> - func.return %0 : tensor<2x3x4x6xf32> -} -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2, 3]] : tensor<2x2x1x6xf32> into tensor<24xf32> -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1, 2, 3]] output_shape [2, 2, 2, 3] : tensor<24xf32> into tensor<2x2x2x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x4x2x3xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32> -// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[IN]], %[[EXPAND]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32> -// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]] -// CHECK-SAME: [0], [1], [2], [3, 4] -// CHECK-SAME: : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv_with_padding -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv_with_padding( - %arg0: tensor<2x4x5x2xf32>, - %arg1: tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 2 : i64, - padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array, - someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> - func.return %0 : tensor<2x3x6x4xf32> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 0, 1, 0] high[0, 0, 1, 0] { -// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK: tensor.yield %[[ZERO]] : f32 -// CHECK } : tensor<2x4x5x2xf32> to tensor<2x4x7x2xf32> -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0, 1, 2, 3] -// CHECK-SAME: : tensor<2x2x1x4xf32> into tensor<16xf32> -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] -// CHECK-SAME: [0, 1, 2, 3] -// CHECK-SAME: tensor<16xf32> into tensor<2x2x2x2xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x6x2x2xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32> -// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[PAD]], %[[EXPAND]] : tensor<2x4x7x2xf32>, tensor<2x2x2x2xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32> -// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]] -// CHECK-SAME: [0], [1], [2], [3, 4] -// CHECK-SAME: : tensor<2x3x6x2x2xf32> into tensor<2x3x6x4xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv_multiplier_1 -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>, - %arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 96 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> - func.return %0 : tensor<1x56x56x96xf32> -} -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x56x56x96xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0], [1], [2, 3] -// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32> -// CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} -// CHECK-SAME: ins(%[[IN]], %[[RESHAPED_FILTER]] : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv_multiplier_1_with_padding -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv_multiplier_1_with_padding( - %arg0: tensor<1x113x113x96xf32>, - %arg1: tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 96 : i64, - padding = dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> - func.return %0 : tensor<1x57x58x96xf32> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 1, 2, 0] high[0, 1, 2, 0] { -// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK: tensor.yield %[[ZERO]] : f32 -// CHECK } : tensor<1x113x113x96xf32> to tensor<1x115x117x96xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x58x96xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[INIT]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0], [1], [2, 3] -// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32> -// CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} -// CHECK-SAME: ins(%[[PAD]], %[[RESHAPED_FILTER]] : tensor<1x115x117x96xf32>, tensor<3x3x96xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv1d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv1d(%arg0: tensor<1x10x8xf32>, - %arg1: tensor<3x1x16xf32>) -> tensor<1x10x16xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], - window = { - stride = [1], - pad = [[1, 1]], - lhs_dilate = [1], - rhs_dilate = [1], - reverse = [0]} { - batch_group_count = 1 : i64, - feature_group_count = 8 : i64, - someattr} : (tensor<1x10x8xf32>, tensor<3x1x16xf32>) -> tensor<1x10x16xf32> - func.return %0 : tensor<1x10x16xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_1d_nwc_wcm -// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[CONV]] -// CHECK: return %[[OUT]] - -// ----- - -// CHECK-LABEL: func @depthwise_conv1d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv1d_m1(%arg0: tensor<1x10x8xf32>, - %arg1: tensor<3x1x8xf32>) -> tensor<1x10x8xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], - window = { - stride = [1], - pad = [[1, 1]], - lhs_dilate = [1], - rhs_dilate = [1], - reverse = [0]} { - batch_group_count = 1 : i64, - feature_group_count = 8 : i64, - someattr} : (tensor<1x10x8xf32>, tensor<3x1x8xf32>) -> tensor<1x10x8xf32> - func.return %0 : tensor<1x10x8xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_1d_nwc_wc -// CHECK: return %[[CONV]] - -// ----- - -// CHECK-LABEL: func @depthwise_conv3d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv3d(%arg0: tensor<2x3x5x4x6xf32>, - %arg1: tensor<2x1x3x1x36xf32>) - -> tensor<2x3x13x4x36xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], - window = { - stride = [2, 1, 3], - pad = [[1, 2], [5, 3], [3, 5]], - lhs_dilate = [1, 1, 1], - rhs_dilate = [1, 1, 1], - reverse = [0, 0, 0]} { - batch_group_count = 1 : i64, - feature_group_count = 6 : i64, - someattr} : (tensor<2x3x5x4x6xf32>, tensor<2x1x3x1x36xf32>) - -> tensor<2x3x13x4x36xf32> - func.return %0 : tensor<2x3x13x4x36xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_3d_ndhwc_dhwcm -// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[CONV]] -// CHECK: return %[[OUT]] - -// ----- - -// CHECK-LABEL: func @depthwise_conv3d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv3d_m1(%arg0: tensor<2x3x5x4x6xf32>, - %arg1: tensor<2x1x3x1x6xf32>) - -> tensor<2x3x13x4x6xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], - window = { - stride = [2, 1, 3], - pad = [[1, 2], [5, 3], [3, 5]], - lhs_dilate = [1, 1, 1], - rhs_dilate = [1, 1, 1], - reverse = [0, 0, 0]} { - batch_group_count = 1 : i64, - feature_group_count = 6 : i64, - someattr} : (tensor<2x3x5x4x6xf32>, tensor<2x1x3x1x6xf32>) - -> tensor<2x3x13x4x6xf32> - func.return %0 : tensor<2x3x13x4x6xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_3d_ndhwc_dhwc -// CHECK: return %[[CONV]] diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir deleted file mode 100644 index 1c4f3e2bfcbb..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir +++ /dev/null @@ -1,276 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// Note: We need the canonicalization pass to deduplicate constants. This test -// does not rely on it to simplify arithmetic, etc. - -func.func @dot_general(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [2], - rhs_contracting_dimensions = [1] - >, - precision_config = [#stablehlo, #stablehlo], - someattr - } : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// The iterations are (Batch Dim, LHS Other Dim, RHS Other dim, Contracting Dim) -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)> -// Output is the iterators excluding contracting -// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK: func @dot_general( -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// Only contracting dims are reductions -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK-SAME: {someattr} -// CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG2]], %[[ARG3]] : f32 -// CHECK: %[[SUM:.*]] = arith.addf %[[ARG4]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[SUM]] : f32 -// CHECK: } -> tensor - -// ----- - -func.func @dot_general_unsigned(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [2], - rhs_contracting_dimensions = [1] - >, - precision_config = [#stablehlo, #stablehlo], - someattr - } : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @dot_general_unsigned( -// CHECK: linalg.generic -// CHECK-SAME: ins({{.*}} : tensor, tensor) -// CHECK-SAME: outs({{.*}} : tensor) - -// ----- - -func.func @dot_general_complex(%arg0: tensor>, - %arg1: tensor>) -> tensor> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [2], - rhs_contracting_dimensions = [1] - >, - precision_config = [#stablehlo, #stablehlo], - someattr - } : (tensor>, tensor>) -> tensor> - func.return %0 : tensor> -} - -// CHECK-LABEL: func @dot_general_complex( -// CHECK: linalg.generic -// CHECK: complex.mul -// CHECK: complex.add - -// ----- - -func.func @dot_general_multiple_batch_dimensions(%arg0: tensor<3x4x2x4xi32>, - %arg1: tensor<3x4x3x2xi32>) -> tensor<3x4x4x3xi32> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [3]>, - precision_config = [#stablehlo, #stablehlo], - someattr - } : (tensor<3x4x2x4xi32>, tensor<3x4x3x2xi32>) -> tensor<3x4x4x3xi32> - return %0 : tensor<3x4x4x3xi32> -} - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d2)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> -// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> -// CHECK: func @dot_general_multiple_batch_dimensions -// CHECK-SAME: (%[[ARG0:.+]]: tensor<3x4x2x4xi32>, %[[ARG1:.+]]: tensor<3x4x3x2xi32>) -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3x4x2x4xi32>, tensor<3x4x3x2xi32>) -// CHECK-SAME: outs({{.*}} : tensor<3x4x4x3xi32>) -// CHECK-SAME: {someattr} - -// ----- - -func.func @dot_matmul(%arg0: tensor<2x3xf32>, - %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> { - %0 = "stablehlo.dot"(%arg0, %arg1) {someattr} - : (tensor<2x3xf32>, tensor<3x?xf32>) -> tensor<2x?xf32> - func.return %0 : tensor<2x?xf32> -} -// CHECK-LABEL: func @dot_matmul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: {someattr} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>) - -// ----- - -func.func @dot_matmul_complex(%arg0: tensor<2x3xcomplex>, - %arg1: tensor<3x?xcomplex>) -> tensor<2x?xcomplex> { - %0 = "stablehlo.dot"(%arg0, %arg1) {someattr} - : (tensor<2x3xcomplex>, tensor<3x?xcomplex>) -> tensor<2x?xcomplex> - func.return %0 : tensor<2x?xcomplex> -} -// CHECK-LABEL: func @dot_matmul_complex( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xcomplex>, %[[ARG1:.*]]: tensor<3x?xcomplex>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: {someattr} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xcomplex>, tensor<3x?xcomplex>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xcomplex>) - -// ----- - -func.func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>, - %arg1: tensor<3x?xi8>) -> tensor<2x?xi32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>, - tensor<3x?xi8>) -> tensor<2x?xi32> - func.return %0 : tensor<2x?xi32> -} -// CHECK-LABEL: func @dot_matmul_i8_i8_i32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) - -// ----- - -func.func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>, - %arg1: tensor<3x?xi16>) -> tensor<2x?xi32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xi16>, - tensor<3x?xi16>) -> tensor<2x?xi32> - func.return %0 : tensor<2x?xi32> -} -// CHECK-LABEL: func @dot_matmul_i16_i16_i32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) - -// ----- - -func.func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>, - %arg1: tensor<3x?xi32>) -> tensor<2x?xi32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xi32>, - tensor<3x?xi32>) -> tensor<2x?xi32> - func.return %0 : tensor<2x?xi32> -} -// CHECK-LABEL: func @dot_matmul_i32_i32_i32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) - -// ----- - -func.func @dot_matvec(%arg0: tensor, - %arg1: tensor<3xf32>) -> tensor { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor, - tensor<3xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-LABEL: func @dot_matvec( -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<3xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D0]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matvec -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) - -// ----- - -func.func @dot_vecmat(%arg0: tensor<3xf32>, - %arg1: tensor<3x?xf32>) -> tensor { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<3xf32>, - tensor<3x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-LABEL: func @dot_vecmat( -// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.vecmat -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3xf32>, tensor<3x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) - -// ----- - -func.func @dot_dot(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-LABEL: func @dot_dot( -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -// CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.dot -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor) - -// ----- - -func.func @dot_dot_unsigned(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-LABEL: func @dot_dot_unsigned( -// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -// CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}outs(%[[INIT]] -// CHECK: linalg.dot -// CHECK-SAME: ins(%{{.*}} : tensor, tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor) diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir deleted file mode 100644 index e585d5963a33..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir +++ /dev/null @@ -1,325 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func.func @gather( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather(%operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xi32>) -> tensor<1x8x8xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array, - someattr - } : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32> - func.return %res : tensor<1x8x8xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: outs(%[[INIT]] : tensor<1x8x8xi32>) -// CHECK-SAME: {someattr} -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[IDX1]], %[[C0]]] : tensor<1x8x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[IDX1]], %[[C1]]] : tensor<1x8x2xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.minsi %[[CLAMP0]], %[[C0]] -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[IN1:.+]] = arith.minsi %[[CLAMP1]], %[[C3]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]], %[[IDX2]]] : tensor<1x4x8xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK-DAG: return %[[RES]] - -// ----- - -// CHECK-LABEL: func.func @gather_unsigned_index( -func.func @gather_unsigned_index( - %operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xui32>) - -> tensor<1x8x8xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array, - someattr - } : (tensor<1x4x8xi32>, tensor<1x8x2xui32>) -> tensor<1x8x8xi32> - func.return %res : tensor<1x8x8xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK: %[[S0_INT:.+]] = tensor.extract {{.*}}[{{.*}}, %[[C0]]] -// CHECK: arith.index_castui %[[S0_INT]] : i32 to index -// CHECK: %[[S1_INT:.+]] = tensor.extract {{.*}}[{{.*}}, %[[C1]]] -// CHECK: arith.index_castui %[[S1_INT]] : i32 to index - -// ----- - -// CHECK-LABEL: func @gather_unsigned( -func.func @gather_unsigned(%operand : tensor<1x4x8xui32>, %start_indices : tensor<1x8x2xi32>) -> tensor<1x8x8xui32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array - } : (tensor<1x4x8xui32>, tensor<1x8x2xi32>) -> tensor<1x8x8xui32> - func.return %res : tensor<1x8x8xui32> -} - -// CHECK: linalg.generic -// CHECK-SAME: outs(%{{.*}} : tensor<1x8x8xi32>) - -// ----- - -// CHECK-LABEL: func.func @gather_no_collapse( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather_no_collapse(%operand : tensor<6x3xi32>, %start_indices : tensor<5x2xi32>) -> tensor<5x4x2xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array - } : (tensor<6x3xi32>, tensor<5x2xi32>) -> tensor<5x4x2xi32> - func.return %res : tensor<5x4x2xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x4x2xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<5x4x2xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C0]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C1]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[C2]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX1]] : index -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP1_1:.+]] = arith.minsi %[[CLAMP1]], %[[C1]] -// CHECK-DAG: %[[IN1:.+]] = arith.addi %[[CLAMP1_1]], %[[IDX2]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]]] : tensor<6x3xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - - -// ----- - -func.func @gather_max_offset(%operand : tensor, %start_indices : tensor<5x2xi32>) -> tensor<2x3x4x5xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [0, 1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array - } : (tensor, tensor<5x2xi32>) -> tensor<2x3x4x5xi32> - func.return %res : tensor<2x3x4x5xi32> -} - -// CHECK-LABEL: func @gather_max_offset( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2x3x4x5xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<2x3x4x5xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX3]], %[[C0]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX3]], %[[C1]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C2]] -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] -// CHECK-DAG: %[[L1:.+]] = arith.subi %[[D1]], %[[C3]] -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP1_1:.+]] = arith.minsi %[[CLAMP1]], %[[L1]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index -// CHECK-DAG: %[[IN1:.+]] = arith.addi %[[CLAMP1_1]], %[[IDX1]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]], %[[IDX2]]] : tensor -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - -// ----- - -func.func @gather_reorder_start_index(%operand : tensor<6x3x2x7xi32>, %start_indices : tensor<5x4xi32>) -> tensor<5x2x4xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 2], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [3, 1, 2, 0] - >, - indices_are_sorted = false, - slice_sizes = array - } : (tensor<6x3x2x7xi32>, tensor<5x4xi32>) -> tensor<5x2x4xi32> - func.return %res : tensor<5x2x4xi32> -} - -// CHECK-LABEL: func @gather_reorder_start_index( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[C5:.+]] = arith.constant 5 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x2x4xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<5x2x4xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C0]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C1]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[S2_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C2]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S2:.+]] = arith.index_cast %[[S2_INT]] : i32 to index -// CHECK-DAG: %[[S3_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C3]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S3:.+]] = arith.index_cast %[[S3_INT]] : i32 to index -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[C3]] -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP1_1:.+]] = arith.minsi %[[CLAMP1]], %[[C1]] -// CHECK-DAG: %[[CLAMP2:.+]] = arith.maxsi %[[S2]], %[[C0]] : index -// CHECK-DAG: %[[IN2:.+]] = arith.minsi %[[CLAMP2]], %[[C1]] -// CHECK-DAG: %[[CLAMP3:.+]] = arith.maxsi %[[S3]], %[[C0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.minsi %[[CLAMP3]], %[[C5]] -// CHECK-DAG: %[[IN1:.+]] = arith.addi %[[CLAMP1_1]], %[[IDX1]] : index -// CHECK-DAG: %[[IN3:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX2]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]], %[[IN2]], %[[IN3]]] : tensor<6x3x2x7xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - -// ----- - -// CHECK-LABEL: func.func @gather_implicit_trailing_dim( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather_implicit_trailing_dim(%operand : tensor, %start_indices : tensor<5x2xi32>) -> tensor<3x4x5x2xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 2, - offset_dims = [0, 1], - start_index_map = [0] - >, - indices_are_sorted = false, - slice_sizes = array - } : (tensor, tensor<5x2xi32>) -> tensor<3x4x5x2xi32> - func.return %res : tensor<3x4x5x2xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<3x4x5x2xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<3x4x5x2xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX2]], %[[IDX3]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C3]] -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IDX1]]] : tensor -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - -// ----- - -// CHECK-LABEL: func.func @gather_non_static( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather_non_static(%operand : tensor, %start_indices : tensor) -> tensor<3x4x?xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [0, 1], - start_index_map = [0] - >, - indices_are_sorted = false, - slice_sizes = array - } : (tensor, tensor) -> tensor<3x4x?xi32> - func.return %res : tensor<3x4x?xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[DYN_DIM:.+]] = tensor.dim %[[START_INDICES]], %[[C0]] -// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[DYN_DIM]]) : tensor<3x4x?xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<3x4x?xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX2]], %[[C0]]] : tensor -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C3]] -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IDX1]]] : tensor -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir deleted file mode 100644 index e1421652e319..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir +++ /dev/null @@ -1,1489 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// RUN: iree-opt %s --iree-stablehlo-to-linalg="enable-primitive-ops=true" \ -// RUN: --split-input-file --canonicalize | \ -// RUN: FileCheck %s --check-prefix=CHECK-PRIMITIVE - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func @float_add -// CHECK-PRIMITIVE-LABEL: func @float_add -func.func @float_add(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK-SAME: {someattr} - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 - // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]] - // CHECK: linalg.yield %[[RESULT]] - - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: arith.addf - %0 = "stablehlo.add"(%lhs, %rhs) {someattr} - : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @float_add_dynamic_encoding -// CHECK-PRIMITIVE-LABEL: func @float_add_dynamic_encoding -func.func @float_add_dynamic_encoding( - %lhs: tensor<2x?xf32, #stablehlo.type_extensions>, - %rhs: tensor<2x?xf32, #stablehlo.type_extensions>) - -> tensor<2x?xf32, #stablehlo.type_extensions> { - // CHECK: linalg.generic - // CHECK: arith.addf - // CHECK: linalg.yield - - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: arith.addf - %0 = "stablehlo.add"(%lhs, %rhs) - : (tensor<2x?xf32, #stablehlo.type_extensions>, - tensor<2x?xf32, #stablehlo.type_extensions>) - -> tensor<2x?xf32, #stablehlo.type_extensions> - func.return %0 : tensor<2x?xf32, #stablehlo.type_extensions> -} - -// ----- - -// CHECK-LABEL: integer_add -// CHECK-PRIMITIVE-LABEL: integer_add -func.func @integer_add(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: addi - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: addi - %0 = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: complex_add -// CHECK-PRIMITIVE-LABEL: complex_add -func.func @complex_add(%lhs: tensor<2x2xcomplex>, - %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.add - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.add - %0 = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex>, - tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @complex_atan2 -// CHECK-PRIMITIVE-LABEL: func @complex_atan2 -func.func @complex_atan2(%lhs: tensor<2x2xcomplex>, - %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - %tensor_result = "stablehlo.atan2"(%lhs, %rhs) - : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - // CHECK: linalg.generic - // CHECK: complex.atan2 - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.atan2 - func.return %tensor_result : tensor<2x2xcomplex> -} - - -// ----- - -// CHECK-LABEL: func @float_mul -// CHECK-PRIMITIVE-LABEL: func @float_mul -func.func @float_mul(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: mulf - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: mulf - %0 = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, - tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @integer_mul -// CHECK-PRIMITIVE-LABEL: func @integer_mul -func.func @integer_mul(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: muli - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: muli - %0 = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @complex_mul -// CHECK-PRIMITIVE-LABEL: func @complex_mul -func.func @complex_mul(%lhs: tensor<2x2xcomplex>, - %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.mul - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.mul - %0 = "stablehlo.multiply"(%lhs, %rhs) - : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_remainder -// CHECK-PRIMITIVE-LABEL: func @float_remainder -func.func @float_remainder(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: remf - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: remf - %0 = "stablehlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, - tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @integer_remainder -// CHECK-PRIMITIVE-LABEL: func @integer_remainder -func.func @integer_remainder(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: arith.remsi - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: arith.remsi - %0 = "stablehlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @population_count_integer -// CHECK-PRIMITIVE-LABEL: func @population_count_integer -func.func @population_count_integer(%lhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: math.ctpop - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.ctpop - %0 = "stablehlo.popcnt"(%lhs) : (tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @complex_sqrt -// CHECK-PRIMITIVE-LABEL: func @complex_sqrt -func.func @complex_sqrt(%operand: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - %tensor_result = "stablehlo.sqrt"(%operand) - : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - // CHECK: linalg.generic - // CHECK: complex.sqrt - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.sqrt - func.return %tensor_result : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_rsqrt -// CHECK-PRIMITIVE-LABEL: func @float_rsqrt -func.func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { - %tensor_result = "stablehlo.rsqrt"(%operand) - : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: linalg.generic - // CHECK: rsqrt - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: rsqrt - func.return %tensor_result : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_rsqrt -// CHECK-PRIMITIVE-LABEL: func @complex_rsqrt -func.func @complex_rsqrt(%operand: tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> { - %tensor_result = "stablehlo.rsqrt"(%operand) - : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - // CHECK: linalg.generic - // CHECK: complex.rsqrt - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.rsqrt - func.return %tensor_result : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_cbrt -// CHECK-PRIMITIVE-LABEL: func @float_cbrt -func.func @float_cbrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { - %tensor_result = "stablehlo.cbrt"(%operand) - : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[RESULT:.+]] = math.cbrt %[[IN]] - // CHECK: linalg.yield %[[RESULT]] - - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.cbrt - func.return %tensor_result : tensor<2x2xf32> -} - -// ----- - - -// CHECK-LABEL: func @float_sub -// CHECK-PRIMITIVE-LABEL: func @float_sub -func.func @float_sub(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: subf - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: subf - %0 = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, - tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @integer_sub -// CHECK-PRIMITIVE-LABEL: func @integer_sub -func.func @integer_sub(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: subi - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: subi - %0 = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: complex_sub -// CHECK-PRIMITIVE-LABEL: complex_sub -func.func @complex_sub(%lhs: tensor<2x2xcomplex>, - %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.sub - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.sub - %0 = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex>, - tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_abs -// CHECK-PRIMITIVE-LABEL: func @float_abs -func.func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK-SAME: {someattr} - // CHECK: math.absf - // CHECK-PRIMITIVE: linalg.map { math.absf } - // CHECK-PRIMITIVE-SAME: ins( - // CHECK-PRIMITIVE-SAME: outs( - // CHECK-PRIMITIVE-SAME: {someattr} - %0 = "stablehlo.abs"(%arg0) {someattr} : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @float_exp -// CHECK-PRIMITIVE-LABEL: func @float_exp -func.func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: exp - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: exp - %0 = "stablehlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_exp -func.func @complex_exp(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.exp - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.exp - %0 = "stablehlo.exponential"(%arg0) : (tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_expm1 -func.func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: expm1 - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: expm1 - %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_expm1 -func.func @complex_expm1(%arg0: tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.expm1 - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.expm1 - %0 = "stablehlo.exponential_minus_one"(%arg0) - : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_log -func.func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: math.log - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.log - %0 = "stablehlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_log -func.func @complex_log(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.log - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.log - %0 = "stablehlo.log"(%arg0) : (tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_log1p -// CHECK-PRIMITIVE-LABEL: func @float_log1p -func.func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: math.log1p - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.log1p - %0 = "stablehlo.log_plus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_log1p -// CHECK-PRIMITIVE-LABEL: func @complex_log1p -func.func @complex_log1p(%arg0: tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.log1p - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.log1p - %0 = "stablehlo.log_plus_one"(%arg0) : (tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_logistic -// CHECK-PRIMITIVE-LABEL: func @float_logistic -func.func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: %[[C1:.*]] = arith.constant 1.{{.*}}e+00 - // CHECK: linalg.generic - // CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32): - // CHECK: %[[NEG_ARG:.*]] = arith.negf %[[ARG]] - // CHECK: %[[EXP_NEG_ARG:.*]] = math.exp %[[NEG_ARG]] - // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = arith.addf %[[EXP_NEG_ARG]], %[[C1]] - // CHECK: %[[RESULT:.*]] = arith.divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]] - // CHECK: linalg.yield %[[RESULT]] - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: arith.negf - // CHECK-PRIMITIVE: math.exp - // CHECK-PRIMITIVE: arith.addf - // CHECK-PRIMITIVE: arith.divf - %0 = "stablehlo.logistic"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_logistic -func.func @complex_logistic(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK: linalg.generic - // CHECK: ^bb0(%[[ARG:.*]]: complex, %{{.*}}: complex): - // CHECK: %[[NEG_ARG:.*]] = complex.neg %[[ARG]] - // CHECK: %[[EXP_NEG_ARG:.*]] = complex.exp %[[NEG_ARG]] - // CHECK: %[[CC1:.*]] = complex.create %[[C1]], %[[C0]] : complex - // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = complex.add %[[EXP_NEG_ARG]], %[[CC1]] - // CHECK: %[[RESULT:.*]] = complex.div %[[CC1]], %[[ONE_ADD_EXP_NEG_ARG]] - // CHECK: linalg.yield %[[RESULT]] - %0 = "stablehlo.logistic"(%arg0) : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_ceil -// CHECK-PRIMITIVE-LABEL: func @float_ceil -func.func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: math.ceil - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.ceil - %0 = "stablehlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @floor -// CHECK-PRIMITIVE-LABEL: func @floor -func.func @floor(%input: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: math.floor - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.floor - %0 = "stablehlo.floor"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @float_neg -// CHECK-PRIMITIVE-LABEL: func @float_neg -func.func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: negf - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: negf - %0 = "stablehlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_neg -// CHECK-PRIMITIVE-LABEL: func @complex_neg -func.func @complex_neg(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.neg - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.neg - %0 = "stablehlo.negate"(%arg0) : (tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @complex_sign -// CHECK-PRIMITIVE-LABEL: func @complex_sign -func.func @complex_sign( - %arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.sign - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.sign - %0 = "stablehlo.sign"(%arg0) : (tensor<2x2xcomplex>) - -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_tanh -// CHECK-PRIMITIVE-LABEL: func @float_tanh -func.func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: tanh - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: tanh - %0 = "stablehlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_tanh -// CHECK-PRIMITIVE-LABEL: func @complex_tanh -func.func @complex_tanh(%operand: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - %tensor_result = "stablehlo.tanh"(%operand) - : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - // CHECK: linalg.generic - // CHECK: complex.tanh - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.tanh - func.return %tensor_result : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @integer_and -// CHECK-PRIMITIVE-LABEL: func @integer_and -func.func @integer_and(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: and - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: and - %0 = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @integer_or -// CHECK-PRIMITIVE-LABEL: func @integer_or -func.func @integer_or(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: or - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: or - %0 = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @integer_xor -// CHECK-PRIMITIVE-LABEL: func @integer_xor -func.func @integer_xor(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: xor - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: xor - %0 = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @count_leading_zeros -// CHECK-PRIMITIVE-LABEL: func @count_leading_zeros -func.func @count_leading_zeros(%lhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: math.ctlz - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.ctlz - %0 = "stablehlo.count_leading_zeros"(%lhs) : (tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: unsigned_convert -func.func @unsigned_convert(%in: tensor<2x2xui32>) -> tensor<2x2xui64> { - // CHECK: linalg.generic - // CHECK: arith.extui - %0 = "stablehlo.convert"(%in) : (tensor<2x2xui32>) -> tensor<2x2xui64> - func.return %0 : tensor<2x2xui64> -} - -// ----- - -// CHECK-LABEL: func @float_cmp -// CHECK-PRIMITIVE-LABEL: func @float_cmp -func.func @float_cmp(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { - %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo} - : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - func.return %0 : tensor<2x2xi1> -} -// CHECK: tensor.empty() : tensor<2x2xi1> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1): -// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpf oeq, %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.cmpf - -// ----- - -// CHECK-LABEL: func @float_cmp_ne -// CHECK-PRIMITIVE-LABEL: func @float_cmp_ne -func.func @float_cmp_ne(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { - %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo} - : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - func.return %0 : tensor<2x2xi1> -} -// CHECK: tensor.empty() : tensor<2x2xi1> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1): -// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpf une, %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.cmpf - -// ----- - -// CHECK-LABEL: func @float_cmp_totalorder -// CHECK-PRIMITIVE-LABEL: func @float_cmp_totalorder -func.func @float_cmp_totalorder(%lhs: tensor<2x2xbf16>, - %rhs: tensor<2x2xbf16>) -> (tensor<2x2xi1>) { - %0 = "stablehlo.compare"(%lhs, %rhs) { - comparison_direction = #stablehlo, - compare_type = #stablehlo - } : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xi1> - func.return %0 : tensor<2x2xi1> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i16 -// CHECK-DAG: %[[C32767:.*]] = arith.constant 32767 : i16 -// CHECK: tensor.empty() : tensor<2x2xi1> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: bf16, %[[RHS_IN:.*]]: bf16, %{{.*}}: i1): -// CHECK-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16 -// CHECK-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16 -// CHECK-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16 -// CHECK-NEXT: %[[LHS_SELECT:.*]] = arith.select %[[LHS_CMP]], %[[LHS_SUB]], %[[LHS_INT]] : i16 -// CHECK-NEXT: %[[RHS_INT:.*]] = arith.bitcast %[[RHS_IN]] : bf16 to i16 -// CHECK-NEXT: %[[RHS_CMP:.*]] = arith.cmpi slt, %[[RHS_INT]], %[[C0]] : i16 -// CHECK-NEXT: %[[RHS_SUB:.*]] = arith.subi %[[C32767]], %[[RHS_INT]] : i16 -// CHECK-NEXT: %[[RHS_SELECT:.*]] = arith.select %[[RHS_CMP]], %[[RHS_SUB]], %[[RHS_INT]] : i16 -// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpi slt, %[[LHS_SELECT]], %[[RHS_SELECT]] : i16 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 - -// CHECK-PRIMITIVE-DAG: %[[C0:.*]] = arith.constant 0 : i16 -// CHECK-PRIMITIVE-DAG: %[[C32767:.*]] = arith.constant 32767 : i16 -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE-SAME: ins( -// CHECK-PRIMITIVE-SAME: outs( -// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16) { -// CHECK-PRIMITIVE-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16 -// CHECK-PRIMITIVE-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16 -// CHECK-PRIMITIVE-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16 -// CHECK-PRIMITIVE-NEXT: %[[LHS_SELECT:.*]] = arith.select %[[LHS_CMP]], %[[LHS_SUB]], %[[LHS_INT]] : i16 -// CHECK-PRIMITIVE-NEXT: %[[RHS_INT:.*]] = arith.bitcast %[[RHS_IN]] : bf16 to i16 -// CHECK-PRIMITIVE-NEXT: %[[RHS_CMP:.*]] = arith.cmpi slt, %[[RHS_INT]], %[[C0]] : i16 -// CHECK-PRIMITIVE-NEXT: %[[RHS_SUB:.*]] = arith.subi %[[C32767]], %[[RHS_INT]] : i16 -// CHECK-PRIMITIVE-NEXT: %[[RHS_SELECT:.*]] = arith.select %[[RHS_CMP]], %[[RHS_SUB]], %[[RHS_INT]] : i16 -// CHECK-PRIMITIVE-NEXT: %[[RESULT:.*]] = arith.cmpi slt, %[[LHS_SELECT]], %[[RHS_SELECT]] : i16 -// CHECK-PRIMITIVE-NEXT: linalg.yield %[[RESULT]] : i1 - -// ----- - -// CHECK-LABEL: func @int_cmp -// CHECK-PRIMITIVE-LABEL: func @int_cmp -func.func @int_cmp(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> { - %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo} - : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) - func.return %0 : tensor<2x2xi1> -} -// CHECK: tensor.empty() : tensor<2x2xi1> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i1): -// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.cmpi - -// ----- - -// CHECK-LABEL: func @complex_cmp_eq -// CHECK-PRIMITIVE-LABEL: func @complex_cmp_eq -func.func @complex_cmp_eq(%lhs: tensor<2xcomplex>, - %rhs: tensor<2xcomplex>) -> tensor<2xi1> { - %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo} - : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xi1>) - func.return %0 : tensor<2xi1> -} -// CHECK: tensor.empty() : tensor<2xi1> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: i1): -// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: complex.eq - -// ----- - -// CHECK-LABEL: func @complex_cmp_neq -// CHECK-PRIMITIVE-LABEL: func @complex_cmp_neq -func.func @complex_cmp_neq(%lhs: tensor<2xcomplex>, - %rhs: tensor<2xcomplex>) -> tensor<2xi1> { - %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo} - : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xi1>) - func.return %0 : tensor<2xi1> -} -// CHECK: tensor.empty() : tensor<2xi1> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex, %[[RHS_IN:.*]]: complex, %[[RESULT_OUT:.*]]: i1): -// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: complex.neq - -// ----- - -// CHECK-LABEL: func @float_cos -// CHECK-PRIMITIVE-LABEL: func @float_cos -func.func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: math.cos - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.cos - %0 = "stablehlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_cos -// CHECK-PRIMITIVE-LABEL: func @complex_cos -func.func @complex_cos(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.cos - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.cos - %0 = "stablehlo.cosine"(%arg0) : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @float_sin -// CHECK-PRIMITIVE-LABEL: func @float_sin -func.func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: math.sin - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.sin - %0 = "stablehlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_sin -// CHECK-PRIMITIVE-LABEL: func @complex_sin -func.func @complex_sin(%arg0: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.sin - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: complex.sin - %0 = "stablehlo.sine"(%arg0) : (tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK-LABEL: func @is_finte -// CHECK-PRIMITIVE-LABEL: func @is_finte -func.func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { - %0 = "stablehlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> - func.return %0 : tensor<2x2xi1> -} -// CHECK: %[[POS_INF:.+]] = arith.constant 0x7F800000 : f32 -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32 -// CHECK-NEXT: %[[ABS_X:.+]] = math.absf %[[OPERAND_IN]] : f32 -// CHECK-NEXT: %[[RESULT:.+]] = arith.cmpf one, %[[ABS_X]], %[[POS_INF]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: math.absf -// CHECK-PRIMITIVE: arith.cmpf - -// ----- - -// CHECK-LABEL: func @round_nearest_even -// CHECK-PRIMITIVE-LABEL: func @round_nearest_even -func.func @round_nearest_even(%val: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32 - // CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[ROUND:.+]] = math.roundeven %[[IN]] - // CHECK: linalg.yield %[[ROUND]] - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.roundeven - %0 = "stablehlo.round_nearest_even"(%val) : (tensor<2x2xf32>) -> (tensor<2x2xf32>) - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @round -// CHECK-PRIMITIVE-LABEL: func @round -func.func @round(%val: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32 - // CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[ROUND:.+]] = math.round %[[IN]] - // CHECK: linalg.yield %[[ROUND]] - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: math.round - %0 = "stablehlo.round_nearest_afz"(%val) : (tensor<2x2xf32>) -> (tensor<2x2xf32>) - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @select -func.func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "stablehlo.select"(%pred, %lhs, %rhs) - : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) - func.return %0 : tensor<2x2xf32> -} -// CHECK: tensor.empty() : tensor<2x2xf32> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @select -// CHECK-PRIMITIVE: tensor.empty() : tensor<2x2xf32> -// CHECK-PRIMITIVE: linalg.map { arith.select } -// CHECK-PRIMITIVE-SAME: ins( -// CHECK-PRIMITIVE-SAME: outs( - -// ----- - -// CHECK-DAG: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()> -// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @select_scalar_pred_dyn -// CHECK-SAME: (%[[PRED:.*]]: tensor, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>) -func.func @select_scalar_pred_dyn(%pred : tensor, %lhs: tensor<2x?xf32>, %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> { - %0 = "stablehlo.select"(%pred, %lhs, %rhs) {someattr} : (tensor, tensor<2x?xf32>, tensor<2x?xf32>) -> (tensor<2x?xf32>) - func.return %0 : tensor<2x?xf32> -} -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[LHS]], %[[C1]] -// CHECK-DAG: %[[DST:.*]] = tensor.empty(%[[DIM]]) -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor, tensor<2x?xf32>, tensor<2x?xf32>) -// CHECK-SAME: outs(%[[DST]] : tensor<2x?xf32>) -// CHECK-SAME: {someattr} -// CHECK: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32): -// CHECK: %[[RES:.*]] = arith.select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32 -// CHECK: linalg.yield %[[RES]] - -// CHECK-PRIMITIVE-LABEL: func @select_scalar_pred_dyn -// CHECK-PRIMITIVE-SAME: (%[[PRED:.*]]: tensor, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>) -// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-PRIMITIVE-DAG: %[[DIM:.*]] = tensor.dim %[[LHS]], %[[C1]] -// CHECK-PRIMITIVE-DAG: %[[DST:.*]] = tensor.empty(%[[DIM]]) -// CHECK-PRIMITIVE-DAG: %[[PRED_ELEM:.*]] = tensor.extract %[[PRED]] -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>) -// CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>) -// CHECK-PRIMITIVE-SAME: {someattr} -// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) { -// CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32 -// CHECK-PRIMITIVE: linalg.yield %[[RES]] - -// ----- - -// CHECK-LABEL: func @select_mixed -func.func @select_mixed(%pred: tensor<2x?xi1>, %lhs: tensor, - %rhs: tensor<2x2xf32>) -> tensor { - %0 = "stablehlo.select"(%pred, %lhs, %rhs) - : (tensor<2x?xi1>, tensor, tensor<2x2xf32>) -> (tensor) - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast_convert -func.func @bitcast_convert(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> - func.return %result : tensor<2x2xf32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.bitcast - -// ----- - -// CHECK-LABEL: func @bitcast_convert_dynamic -func.func @bitcast_convert_dynamic(%input: tensor) -> tensor { - %result = "stablehlo.bitcast_convert"(%input) : (tensor) -> tensor - func.return %result : tensor -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.bitcast - -// ----- - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @bitcast_convert_expand -func.func @bitcast_convert_expand(%input: tensor<6xi32>) -> tensor<6x4xi8> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<6xi32>) -> tensor<6x4xi8> - func.return %result : tensor<6x4xi8> -} - -// CHECK: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK: tensor.empty() : tensor<6x4xi8> -// CHECK: %[[RESULT:.*]] = linalg.generic { -// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]], -// CHECK: iterator_types = ["parallel", "parallel"]} -// CHECK: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]: i8): -// CHECK: %[[IOTA:.*]] = linalg.index 1 : index -// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32 -// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i32 -// CHECK: %[[SHIFT:.*]] = arith.shrui %[[IN]], %[[AMT]] : i32 -// CHECK: %[[TRUNC:.*]] = arith.trunci %[[SHIFT]] : i32 to i8 -// CHECK: linalg.yield %[[TRUNC]] : i8 -// CHECK: } -> tensor<6x4xi8> -// CHECK: return %[[RESULT]] : tensor<6x4xi8> - -// ----- - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: func @bitcast_convert_contract -func.func @bitcast_convert_contract(%input: tensor<7x4xi8>) -> tensor<7xi32> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<7x4xi8>) -> tensor<7xi32> - func.return %result : tensor<7xi32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32 -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<7xi32> -// CHECK: linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<7xi32>) -> tensor<7xi32> -// CHECK: %[[RESULT:.*]] = linalg.generic { -// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]], -// CHECK: iterator_types = ["parallel", "reduction"]} -// CHECK: ^bb0(%[[IN:.*]]: i8, %[[OUT:.*]]: i32): -// CHECK: %[[IOTA:.*]] = linalg.index 1 : index -// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32 -// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i3 -// CHECK: %[[EXT:.*]] = arith.extui %[[IN]] : i8 to i32 -// CHECK: %[[SHIFT:.*]] = arith.shli %[[EXT]], %[[AMT]] : i32 -// CHECK: %[[OR:.*]] = arith.ori %[[SHIFT]], %[[OUT]] : i32 -// CHECK: linalg.yield %[[OR]] : i32 -// CHECK: } -> tensor<7xi32> -// CHECK: return %[[RESULT]] : tensor<7xi32> - -// ----- - -// CHECK-LABEL: signed_divide -func.func @signed_divide(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK-DAG: %[[VAL_7:.*]] = arith.constant -1 : i32 - // CHECK-DAG: %[[VAL_8:.*]] = arith.constant -2147483648 : i32 - // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 1 : i32 - // CHECK: linalg.generic - // CHECK: ^bb0(%[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32): - // CHECK: %[[VAL_11:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_9]] : i32 - // CHECK: %[[VAL_13:.*]] = arith.cmpi eq, %[[VAL_4]], %[[VAL_8]] : i32 - // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_7]] : i32 - // CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_13]], %[[VAL_15]] : i1 - // CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_11]], %[[VAL_16]] : i1 - // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_10]], %[[VAL_5]] : i32 - // CHECK: %[[VAL_19:.*]] = arith.divsi %[[VAL_4]], %[[VAL_18]] : i32 - // CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_16]], %[[VAL_8]], %[[VAL_19]] : i32 - // CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_11]], %[[VAL_7]], %[[VAL_20]] : i32 - // CHECK: linalg.yield %[[VAL_21]] : i32 - %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: unsigned_divide -func.func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> { - // CHECK-DAG: %[[VAL_9:.*]] = arith.constant -1 : i32 - // CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1 : i32 - // CHECK: linalg.generic - // CHECK: ^bb0(%[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32): - // CHECK: %[[VAL_13:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_11]] : i32 - // CHECK: %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_12]], %[[VAL_7]] : i32 - // CHECK: %[[VAL_15:.*]] = arith.divui %[[VAL_6]], %[[VAL_14]] : i32 - // CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_13]], %[[VAL_9]], %[[VAL_15]] : i32 - // CHECK: linalg.yield %[[VAL_16]] : i32 - %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32> - func.return %0 : tensor<2x2xui32> -} - -// ----- - -// CHECK-LABEL: complex_divide -func.func @complex_divide(%lhs: tensor<2xcomplex>, - %rhs: tensor<2xcomplex>) -> tensor<2xcomplex> { - // CHECK: linalg.generic - // CHECK: complex.div - %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xcomplex> - func.return %0 : tensor<2xcomplex> -} - -// ----- - -func.func @shift_left(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %result = "stablehlo.shift_left"(%lhs, %rhs) - : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %result : tensor<2x2xi32> -} -// CHECK-LABEL: func @shift_left -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 -// CHECK-DAG: %[[BITS:.*]] = arith.constant 32 -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32): -// CHECK-DAG: %[[SHIFT:.*]] = arith.shli %[[LHS]], %[[RHS]] : i32 -// CHECK-DAG: %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]] -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[ZERO]] -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// ----- - -func.func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs) - : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %result : tensor<2x2xi32> -} -// CHECK-LABEL: func @shift_right_arithmetic -// CHECK-DAG: %[[BITS:.*]] = arith.constant 32 -// CHECK-DAG: %[[MAX_SHIFT:.*]] = arith.constant 31 -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32): -// CHECK-DAG: %[[SHIFT:.*]] = arith.shrsi %[[LHS]], %[[RHS]] : i32 -// CHECK-DAG: %[[MAX_SHIFTED:.*]] = arith.shrsi %[[LHS]], %[[MAX_SHIFT]] : i32 -// CHECK-DAG: %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]] -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[MAX_SHIFTED]] -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// ----- - -func.func @shift_right_logical(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %result = "stablehlo.shift_right_logical"(%lhs, %rhs) - : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %result : tensor<2x2xi32> -} -// CHECK-LABEL: func @shift_right_logical -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 -// CHECK-DAG: %[[BITS:.*]] = arith.constant 32 -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32): -// CHECK-DAG: %[[SHIFT:.*]] = arith.shrui %[[LHS]], %[[RHS]] : i32 -// CHECK-DAG: %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]] -// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[ZERO]] -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: func @einsum_basic -func.func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "ijk,ikm->ijm", someattr}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> - func.return %0 : tensor<3x4x6xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x5x6xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x4x6xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5xf32>, tensor<3x5x6xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<3x4x6xf32>) -// CHECK-SAME: {someattr} -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @float_pow -func.func @float_pow(%lhs: tensor<2x2xf32>, - %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: linalg.generic - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 - // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = math.powf %[[ARG0]], %[[ARG1]] - // CHECK: linalg.yield %[[RESULT]] - %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xf32>, - tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @complex_pow -func.func @complex_pow(%lhs: tensor<2x2xcomplex>, - %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { - // CHECK: linalg.generic - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: complex - // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: complex - // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = complex.pow %[[ARG0]], %[[ARG1]] - // CHECK: linalg.yield %[[RESULT]] - %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xcomplex>, - tensor<2x2xcomplex>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} - -// ----- - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @integer_pow -func.func @integer_pow(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: linalg.generic - // CHECK: ^{{[a-z0-9_]*}} - // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32 - // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32 - // CHECK: %[[FOR_RESULT:[a-zA-Z0-9_]*]]:3 = scf.for {{.*}} to %c6 step %c1 - // CHECK-SAME: iter_args( - // CHECK-SAME: %[[ITER0:.*]] = %c1 - // CHECK-SAME: %[[ITER1:.*]] = %[[ARG0]], - // CHECK-SAME: %[[ITER2:.*]] = %[[ARG1]] - // CHECK-SAME: ) -> (i32, i32, i32) { - // CHECK: %[[AND:[a-zA-Z0-9_]*]] = arith.andi %[[ITER2]], %c1 - // CHECK: %[[COND:[a-zA-Z0-9_]*]] = arith.cmpi eq, %[[AND]], %c1 - // CHECK: %[[MUL:[a-zA-Z0-9_]*]] = arith.muli %[[ITER0]], %[[ITER1]] - // CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = arith.select %[[COND]], %[[MUL]], %[[ITER0]] - // CHECK: %[[BASE:[a-zA-Z0-9_]*]] = arith.muli %[[ITER1]], %[[ITER1]] - // CHECK: %[[EXP:[a-zA-Z0-9_]*]] = arith.shrui %[[ITER2]], %c1 - // CHECK: scf.yield %[[ACCUM]], %[[BASE]], %[[EXP]] - // CHECK: %[[RHS_PARITY:.*]] = arith.remsi %[[ARG1]], %c2 - // CHECK: %[[RHS_EVEN:.*]] = arith.cmpi eq, %[[RHS_PARITY]], %c0 - // CHECK: %[[RHS_NEG:.*]] = arith.cmpi slt, %[[ARG1]], %c0 - // CHECK: %[[LHS_ONE:.*]] = arith.cmpi eq, %[[ARG0]], %c1 - // CHECK: %[[LHS_NEG_ONE:.*]] = arith.cmpi eq, %[[ARG0]], %c-1 - // CHECK: %[[VAL5:.*]] = arith.extui %[[LHS_ONE]] : i1 to i32 - // CHECK: %[[VAL6:.*]] = arith.select %[[RHS_EVEN]], %c1{{.*}}, %c-1 - // CHECK: %[[VAL7:.*]] = arith.select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]] - // CHECK: %[[RESULT:.*]] = arith.select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]#0 - // CHECK: linalg.yield %[[RESULT]] - %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xi32>, - tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: @real_real -// CHECK-SAME: (%[[ARG0:.*]]: -func.func @real_real(%arg0: tensor) -> tensor { - %1 = "stablehlo.real"(%arg0) : (tensor) -> (tensor) - // CHECK: return %[[ARG0]] - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @imag_real -func.func @imag_real(%arg0: tensor) -> tensor { - %1 = "stablehlo.imag"(%arg0) : (tensor) -> (tensor) - // CHECK: %[[CST:.*]] = arith.constant 0 - // CHECK: linalg.generic - // CHECK: yield %[[CST]] - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: func @minf -func.func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = "stablehlo.minimum"(%lhs, %rhs) {someattr} - : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} -// CHECK: tensor.empty() : tensor<2x2xf32> -// CHECK: linalg.generic -// CHECK-SAME: {someattr} -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.minimumf %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.minimumf - -// ----- - -// CHECK-LABEL: func @maxi -func.func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "stablehlo.maximum"(%lhs, %rhs) - : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} -// CHECK: tensor.empty() : tensor<2x2xi32> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.maxsi %[[LHS_IN]], %[[RHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.maxsi - -// ----- - -// CHECK-LABEL: func @maxu -func.func @maxu(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> { - %0 = "stablehlo.maximum"(%lhs, %rhs) - : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32> - func.return %0 : tensor<2x2xui32> -} -// CHECK: tensor.empty() : tensor<2x2xi32> -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.maxui %[[LHS_IN]], %[[RHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.maxui - -// ----- - -// CHECK-LABEL: func @maxi1 -func.func @maxi1(%lhs: tensor, %rhs: tensor) -> tensor { - %0 = "stablehlo.maximum"(%lhs, %rhs) - : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i1, %[[RHS_IN:.*]]: i1, %{{.*}}: i1): -// CHECK-NEXT: %[[RESULT:.*]] = arith.maxui %[[LHS_IN]], %[[RHS_IN]] : i1 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.maxui - - -// ----- - -// CHECK-LABEL: @clamp_static -// CHECK-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32> -func.func @clamp_static(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) - -> tensor<4xf32> { - // CHECK: %[[INIT:.*]] = tensor.empty - // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) - // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32): - // CHECK: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 - // CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32 - // CHECK: linalg.yield %[[MIN]] - // CHECK: } -> tensor<4xf32> - // CHECK: return %[[RESULT]] : tensor<4xf32> - %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<4xf32>, - tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// CHECK-PRIMITIVE-LABEL: @clamp_static -// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32> - -// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty -// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) -// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32) -// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 -// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32 -// CHECK-PRIMITIVE: linalg.yield %[[MIN]] -// CHECK-PRIMITIVE: return %[[RESULT]] : tensor<4xf32> - -// ----- - -// CHECK-LABEL: @clamp_dynamic -// CHECK-SAME: %[[LB:.*]]: tensor, %[[X:.*]]: tensor, %[[UB:.*]]: tensor -func.func @clamp_dynamic(%lb : tensor, %x : tensor, %ub : tensor) - -> tensor { - // CHECK: %[[INIT:.*]] = tensor.empty - // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor, tensor, tensor) outs(%[[INIT]] : tensor) - // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32): - // CHECK: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 - // CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32 - // CHECK: linalg.yield %[[MIN]] - // CHECK: } -> tensor - // CHECK: return %[[RESULT]] : tensor - %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor, tensor, - tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-PRIMITIVE-LABEL: @clamp_dynamic -// CHECK-PRIMITIVE: linalg.map - -// ----- - -func.func @clamp_mixed(%lb : tensor<4xf32>, %x : tensor, %ub : tensor) - -> tensor { - %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor, - tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @clamp_mixed -// CHECK: linalg.generic - -// CHECK-PRIMITIVE-LABEL: @clamp_mixed -// CHECK-PRIMITIVE: linalg.map - -// ----- - -func.func @clamp_scalar(%lb : tensor, %x : tensor, %ub : tensor) - -> tensor { - %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor, tensor, - tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @clamp_scalar -// CHECK: linalg.generic - -// CHECK-PRIMITIVE-LABEL: @clamp_scalar -// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor, %[[X:.*]]: tensor, %[[UB:.*]]: tensor - -// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.empty -// CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]] -// CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]] -// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor) outs(%[[INIT]] : tensor) -// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32) -// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 -// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32 -// CHECK-PRIMITIVE: linalg.yield %[[MIN]] -// CHECK-PRIMITIVE: return %[[RESULT]] - - -// ----- - -func.func @clamp_scalar_mixed(%lb : tensor, %x : tensor, %ub : tensor) - -> tensor { - %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor, tensor, - tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @clamp_scalar_mixed -// CHECK: linalg.generic - -// CHECK-PRIMITIVE-LABEL: @clamp_scalar_mixed -// CHECK-PRIMITIVE: linalg.map - -// ----- - -// CHECK-LABEL: func @reduce_precision( -// CHECK-DAG: %[[C2:.*]] = arith.constant 1048576 : i32 -// CHECK-DAG: %[[C_21:.*]] = arith.constant 20 : i32 -// CHECK-DAG: %[[C3:.*]] = arith.constant 524287 : i32 -// CHECK-DAG: %[[C4:.*]] = arith.constant -1048576 : i32 -// CHECK-DAG: %[[C5:.*]] = arith.constant 2139095040 : i32 -// CHECK-DAG: %[[C6:.*]] = arith.constant 1090519040 : i32 -// CHECK-DAG: %[[C7:.*]] = arith.constant 1040187392 : i32 -// CHECK-DAG: %[[C8:.*]] = arith.constant -2147483648 : i32 -// CHECK-DAG: %[[C9:.*]] = arith.constant 2147483647 : i32 -// CHECK: linalg.generic -// CHECK: %[[X_AS_INT:.*]] = arith.bitcast %[[IN:.*]] : f32 to i32 -// CHECK: %[[ABS_X:.*]] = arith.andi %[[X_AS_INT]], %[[C9]] -// CHECK: %[[IS_NAN:.*]] = arith.cmpi ugt, %[[ABS_X]], %[[C5]] -// CHECK: %[[MASKED:.*]] = arith.andi %[[X_AS_INT]], %[[C2]] : i32 -// CHECK: %[[V0:.*]] = arith.shrui %[[MASKED]], %[[C_21]] : i32 -// CHECK: %[[V1:.*]] = arith.addi %[[V0]], %[[C3]] : i32 -// CHECK: %[[V2:.*]] = arith.addi %[[X_AS_INT]], %[[V1]] : i32 -// CHECK: %[[V3:.*]] = arith.andi %[[V2]], %[[C4]] : i32 -// CHECK: %[[V4:.*]] = arith.andi %[[V3]], %[[C5]] : i32 -// CHECK: %[[V5:.*]] = arith.cmpi ugt, %[[V4]], %[[C6]] : i32 -// CHECK: %[[V6:.*]] = arith.cmpi ule, %[[V4]], %[[C7]] : i32 -// CHECK: %[[V7:.*]] = arith.andi %[[V3]], %[[C8]] : i32 -// CHECK: %[[V8:.*]] = arith.ori %[[V7]], %[[C5]] : i32 -// CHECK: %[[V9:.*]] = arith.select %[[V5]], %[[V8]], %[[V3]] : i32 -// CHECK: %[[V10:.*]] = arith.select %[[V6]], %[[V7]], %[[V9]] : i32 -// CHECK: %[[CONVERTED:.*]] = arith.bitcast %[[V10]] : i32 to f32 -// CHECK: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[IN]], %[[CONVERTED]] -// CHECK: linalg.yield %[[RESULT]] - -// CHECK-PRIMITIVE-LABEL: func @reduce_precision( -// CHECK-PRIMITIVE: linalg.map -func.func @reduce_precision(%arg0: tensor<1x2x3x4xf32>) - -> tensor<1x2x3x4xf32> { - %0 = "stablehlo.reduce_precision"(%arg0) {exponent_bits=3:i32, mantissa_bits=3:i32} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> - return %0 : tensor<1x2x3x4xf32> -} - -// ----- - -// CHECK-LABEL: func @integer_not -// CHECK-SAME: (%[[ARG:.+]]: tensor<2x2xi32>) -// CHECK-PRIMITIVE-LABEL: func @integer_not -// CHECK-PRIMITIVE-SAME: (%[[ARG:.+]]: tensor<2x2xi32>) -func.func @integer_not(%arg: tensor<2x2xi32>) -> tensor<2x2xi32> { - // CHECK: %[[CST_N1:.+]] = arith.constant -1 : i32 - // CHECK: linalg.generic - // CHECK: (%[[IN:.+]]: i32, %{{.+}}: i32) - // CHECK: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32 - // CHECK: linalg.yield %[[V_NOT]] : i32 - // CHECK-PRIMITIVE: %[[CST_N1:.+]] = arith.constant -1 : i32 - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: (%[[IN:.+]]: i32) - // CHECK-PRIMITIVE: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32 - // CHECK-PRIMITIVE: linalg.yield %[[V_NOT]] : i32 - %0 = "stablehlo.not"(%arg) : (tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0 : tensor<2x2xi32> -} - -// ----- - -// CHECK-LABEL: func @float_complex -// CHECK-SAME: (%[[LHS:.+]]: tensor<2x2xf32>, %[[RHS:.+]]: tensor<2x2xf32>) -// CHECK-PRIMITIVE-LABEL: func @float_complex -// CHECK-PRIMITIVE-SAME: (%[[LHS:.+]]: tensor<2x2xf32>, %[[RHS:.+]]: tensor<2x2xf32>) -func.func @float_complex(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xcomplex> { - // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x2xcomplex> - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[LHS]], %[[RHS]] - // CHECK: (%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %{{.+}}: complex - // CHECK: %[[RES:.+]] = complex.create %[[IN0]], %[[IN1]] : complex - // CHECK: linalg.yield %[[RES]] : complex - // CHECK-PRIMITIVE: %[[INIT:.+]] = tensor.empty() : tensor<2x2xcomplex> - // CHECK-PRIMITIVE: linalg.map { complex.create } ins(%[[LHS]], %[[RHS]] : tensor<2x2xf32>, tensor<2x2xf32>) - // CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor<2x2xcomplex>) - %0 = "stablehlo.complex"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> - func.return %0 : tensor<2x2xcomplex> -} diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir deleted file mode 100644 index b56d63bf9c64..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir +++ /dev/null @@ -1,693 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// CHECK-LABEL: func @rng_uniform_1d -func.func @rng_uniform_1d(%min: tensor, %max: tensor) -> tensor<10xf32> { - %shape = arith.constant dense<[10]> : tensor<1xi32> - %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo} : (tensor, tensor, tensor<1xi32>) -> tensor<10xf32> - func.return %0 : tensor<10xf32> -} -// CHECK-DAG: ^{{.+}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32 -// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32 -// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index -// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 -// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32 -// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32 -// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32 -// CHECK-DAG: %[[VAL2_CAST:.+]] = arith.uitofp %[[VAL2]] : i32 to f32 -// CHECK-DAG: %[[VAL4:.+]] = arith.mulf %[[VAL2_CAST]], %[[FACT]] : f32 -// CHECK-DAG: %[[VAL5:.+]] = arith.addf %[[VAL4]], %[[MIN]] : f32 -// CHECK-NEXT: linalg.yield %[[VAL5]] : f32 -// CHECK-NEXT: -> tensor<10xf32> - -// ----- - -// CHECK-LABEL: func @rng_uniform_2d -func.func @rng_uniform_2d(%min: tensor, %max: tensor) -> tensor<3x3xf32> { - %shape = arith.constant dense<[3, 3]> : tensor<2xi32> - %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo} : (tensor, tensor, tensor<2xi32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} -// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32 -// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32 -// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index -// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index -// CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32 -// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32 -// CHECK-DAG: %[[VAL3:.+]] = arith.addi %[[IDX1_CAST]], %[[VAL2]] : i32 -// CHECK-DAG: %[[VAL4:.+]] = arith.muli %[[VAL3]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL5:.+]] = arith.addi %[[VAL4]], %[[CST1]] : i32 -// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32 -// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32 -// CHECK-DAG: %[[VAL5_CAST:.+]] = arith.uitofp %[[VAL5]] : i32 to f32 -// CHECK-DAG: %[[VAL6:.+]] = arith.mulf %[[VAL5_CAST]], %[[FACT]] : f32 -// CHECK-DAG: %[[VAL7:.+]] = arith.addf %[[VAL6]], %[[MIN]] : f32 -// CHECK-NEXT: linalg.yield %[[VAL7]] : f32 -// CHECK-NEXT: -> tensor<3x3xf32> - -// ----- - -// CHECK-LABEL: func @rng_uniform_3d -func.func @rng_uniform_3d(%min: tensor, %max: tensor) -> tensor<2x2x2xf32> { - %shape = arith.constant dense<[2, 2, 2]> : tensor<3xi32> - %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo} : (tensor, tensor, tensor<3xi32>) -> tensor<2x2x2xf32> - func.return %0 : tensor<2x2x2xf32> -} -// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32 -// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32 -// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index -// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index -// CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index -// CHECK-DAG: %[[IDX2_CAST:.+]] = arith.index_cast %[[IDX2]] : index to i32 -// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32 -// CHECK-DAG: %[[VAL3:.+]] = arith.addi %[[IDX1_CAST]], %[[VAL2]] : i32 -// CHECK-DAG: %[[VAL4:.+]] = arith.muli %[[VAL3]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL5:.+]] = arith.addi %[[VAL4]], %[[CST1]] : i32 -// CHECK-DAG: %[[VAL6:.+]] = arith.addi %[[IDX2_CAST]], %[[VAL5]] : i32 -// CHECK-DAG: %[[VAL7:.+]] = arith.muli %[[VAL6]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL8:.+]] = arith.addi %[[VAL7]], %[[CST1]] : i32 -// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32 -// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32 -// CHECK-DAG: %[[VAL8_CAST:.+]] = arith.uitofp %[[VAL8]] : i32 to f32 -// CHECK-DAG: %[[VAL6:.+]] = arith.mulf %[[VAL8_CAST]], %[[FACT]] : f32 -// CHECK-DAG: %[[VAL7:.+]] = arith.addf %[[VAL6]], %[[MIN]] : f32 -// CHECK-NEXT: linalg.yield %[[VAL7]] : f32 -// CHECK-NEXT: -> tensor<2x2x2xf32> - -// ----- - -// CHECK-LABEL: func @rng_uniform_dynamic_1d -func.func @rng_uniform_dynamic_1d(%min: tensor, %max: tensor, %shape: tensor<1xi32>) -> tensor { - %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[EX:.+]] = tensor.extract %{{.*}}[%[[C0]]] -// CHECK-DAG: %[[IND:.+]] = arith.index_cast %[[EX]] : i32 to index -// CHECK-DAG: %{{.+}} = tensor.empty(%[[IND]]) : tensor -// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32 -// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32 -// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index -// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 -// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32 -// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32 -// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32 -// CHECK-DAG: %[[VAL2_CAST:.+]] = arith.uitofp %[[VAL2]] : i32 to f32 -// CHECK-DAG: %[[VAL4:.+]] = arith.mulf %[[VAL2_CAST]], %[[FACT]] : f32 -// CHECK-DAG: %[[VAL5:.+]] = arith.addf %[[VAL4]], %[[MIN]] : f32 -// CHECK-NEXT: linalg.yield %[[VAL5]] : f32 -// CHECK-NEXT: -> tensor - -// ----- - -// CHECK-LABEL: func.func @three_fry_i64( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @three_fry_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) - return %output_state, %output : tensor<2xi64>, tensor<8xi64> -} -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5 : i32 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 4 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : i32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 24 : i32 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 16 : i32 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : i32 -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 29 : i32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : i32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 6 : i32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 26 : i32 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 17 : i32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 15 : i32 -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 19 : i32 -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 13 : i32 -// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 466688986 : i32 -// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 8 : i64 -// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 32 : i64 -// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 1 : index - -// CHECK-DAG: %[[VAL_21:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[VAL_20]]] : tensor<2xi64> -// CHECK-DAG: %[[VAL_22:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[VAL_19]]] : tensor<2xi64> -// CHECK-DAG: %[[VAL_23:.*]] = arith.trunci %[[VAL_22]] : i64 to i32 -// CHECK-DAG: %[[VAL_24:.*]] = arith.shrui %[[VAL_22]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_25:.*]] = arith.trunci %[[VAL_24]] : i64 to i32 -// CHECK-DAG: %[[VAL_26:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : i64 -// CHECK-DAG: %[[VAL_27:.*]] = tensor.empty() : tensor<8xi64> -// CHECK-DAG: %[[VAL_28:.*]] = arith.xori %[[VAL_23]], %[[VAL_16]] : i32 -// CHECK-DAG: %[[VAL_29:.*]] = arith.xori %[[VAL_28]], %[[VAL_25]] : i32 - -// CHECK: %[[GENERIC:.*]] = linalg.generic -// CHECK-SAME: {indexing_maps = [#map], iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[VAL_27]] : tensor<8xi64>) - -// CHECK: ^bb0(%[[VAL_31:.*]]: i64): - -// CHECK-DAG: %[[VAL_32:.*]] = linalg.index 0 : index -// CHECK-DAG: %[[VAL_33:.*]] = arith.index_cast %[[VAL_32]] : index to i64 -// CHECK-DAG: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_21]] : i64 -// CHECK-DAG: %[[VAL_35:.*]] = arith.trunci %[[VAL_34]] : i64 to i32 -// CHECK-DAG: %[[VAL_36:.*]] = arith.shrui %[[VAL_34]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_37:.*]] = arith.trunci %[[VAL_36]] : i64 to i32 - -// CHECK-DAG: %[[VAL_38:.*]] = arith.addi %[[VAL_35]], %[[VAL_23]] : i32 -// CHECK-DAG: %[[VAL_39:.*]] = arith.addi %[[VAL_37]], %[[VAL_25]] : i32 - -// CHECK-DAG: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : i32 -// CHECK-DAG: %[[VAL_41:.*]] = arith.shli %[[VAL_39]], %[[VAL_15]] : i32 -// CHECK-DAG: %[[VAL_42:.*]] = arith.shrui %[[VAL_39]], %[[VAL_14]] : i32 -// CHECK-DAG: %[[VAL_43:.*]] = arith.ori %[[VAL_41]], %[[VAL_42]] : i32 -// CHECK-DAG: %[[VAL_44:.*]] = arith.xori %[[VAL_40]], %[[VAL_43]] : i32 - -// CHECK-DAG: %[[VAL_45:.*]] = arith.addi %[[VAL_40]], %[[VAL_44]] : i32 -// CHECK-DAG: %[[VAL_46:.*]] = arith.shli %[[VAL_44]], %[[VAL_13]] : i32 -// CHECK-DAG: %[[VAL_47:.*]] = arith.shrui %[[VAL_44]], %[[VAL_12]] : i32 -// CHECK-DAG: %[[VAL_48:.*]] = arith.ori %[[VAL_46]], %[[VAL_47]] : i32 -// CHECK-DAG: %[[VAL_49:.*]] = arith.xori %[[VAL_45]], %[[VAL_48]] : i32 - -// CHECK-DAG: %[[VAL_50:.*]] = arith.addi %[[VAL_45]], %[[VAL_49]] : i32 -// CHECK-DAG: %[[VAL_51:.*]] = arith.shli %[[VAL_49]], %[[VAL_11]] : i32 -// CHECK-DAG: %[[VAL_52:.*]] = arith.shrui %[[VAL_49]], %[[VAL_10]] : i32 -// CHECK-DAG: %[[VAL_53:.*]] = arith.ori %[[VAL_51]], %[[VAL_52]] : i32 -// CHECK-DAG: %[[VAL_54:.*]] = arith.xori %[[VAL_50]], %[[VAL_53]] : i32 - -// CHECK-DAG: %[[VAL_55:.*]] = arith.addi %[[VAL_50]], %[[VAL_54]] : i32 -// CHECK-DAG: %[[VAL_56:.*]] = arith.shli %[[VAL_54]], %[[VAL_10]] : i32 -// CHECK-DAG: %[[VAL_57:.*]] = arith.shrui %[[VAL_54]], %[[VAL_11]] : i32 -// CHECK-DAG: %[[VAL_58:.*]] = arith.ori %[[VAL_56]], %[[VAL_57]] : i32 -// CHECK-DAG: %[[VAL_59:.*]] = arith.xori %[[VAL_55]], %[[VAL_58]] : i32 - -// CHECK-DAG: %[[VAL_60:.*]] = arith.addi %[[VAL_55]], %[[VAL_25]] : i32 -// CHECK-DAG: %[[VAL_61:.*]] = arith.addi %[[VAL_59]], %[[VAL_29]] : i32 -// CHECK-DAG: %[[VAL_62:.*]] = arith.addi %[[VAL_61]], %[[VAL_9]] : i32 -// CHECK-DAG: %[[VAL_63:.*]] = arith.addi %[[VAL_60]], %[[VAL_62]] : i32 -// CHECK-DAG: %[[VAL_64:.*]] = arith.shli %[[VAL_62]], %[[VAL_12]] : i32 -// CHECK-DAG: %[[VAL_65:.*]] = arith.shrui %[[VAL_62]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_66:.*]] = arith.ori %[[VAL_64]], %[[VAL_65]] : i32 - -// CHECK: linalg.yield %[[YIELDED:.*]] : i64 - -// Set the updated state. -// CHECK: %[[VAL_159:.*]] = tensor.insert %[[VAL_26]] into %[[ARG0]]{{\[}}%[[VAL_20]]] : tensor<2xi64> - -// CHECK: return %[[VAL_159]], %[[GENERIC:.*]] : tensor<2xi64>, tensor<8xi64> - -// ----- - -// CHECK-LABEL: func.func @three_fry_i32 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @three_fry_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) - return %output_state, %output : tensor<2xi64>, tensor<8xi32> -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi32> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi32> -// CHECK: %[[GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi32>, tensor<4xi32>) - -// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32> - -// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32> - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi32> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi32>) - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] -// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi32> into tensor<8xi32> -// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> - -// CHECK: return %[[INSERTED]], %[[COLLAPSE]] - - -// ----- - -// CHECK-LABEL: func.func @three_fry_odd_i32 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @three_fry_odd_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) - return %output_state, %output : tensor<2xi64>, tensor<7x11xi32> -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C42]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<42xi32> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<42xi32> -// CHECK: %[[GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<42xi32>, tensor<42xi32>) - -// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0 -// CHECK-SAME{literal}: [[0, 1]] output_shape [7, 6, 1] : tensor<4xi32> into tensor<7x6x1xi32> - -// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1 -// CHECK-SAME{literal}: [[0, 1]] output_shape [7, 6, 1] : tensor<4xi32> into tensor<7x6x1xi32> - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<7x6x2xi32> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<7x6x2xi32>) - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %10 -// CHECK-SAME{literal}: [[0], [1, 2]] : tensor<7x6x2xi32> into tensor<7x12xi32> - -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]][0, 0] [7, 11] [1, 1] -// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: return %[[INSERTED]], %[[SLICE]] : tensor<2xi64>, tensor<7x11xi32> - -// ----- - -// CHECK-LABEL: func.func @three_fry_i16 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @three_fry_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) - return %output_state, %output : tensor<2xi64>, tensor<8xi16> -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi16> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi16> -// CHECK: %[[GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi16>, tensor<4xi16>) - -// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi16> into tensor<4x1xi16> - -// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi16> into tensor<4x1xi16> - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi16> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi16>) - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] -// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi16> into tensor<8xi16> -// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> - -// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi16> - -// ----- - -// CHECK-LABEL: func.func @three_fry_i8 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @three_fry_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) - return %output_state, %output : tensor<2xi64>, tensor<8xi8> -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi8> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi8> -// CHECK: %[[GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi8>, tensor<4xi8>) - -// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi8> into tensor<4x1xi8> - -// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi8> into tensor<4x1xi8> - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi8> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi8>) - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] -// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi8> into tensor<8xi8> -// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> - -// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi8> - -// ----- - -// CHECK-LABEL: func.func @philox_i64 -// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi64> - -func.func @philox_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) - return %output_state, %output : tensor<2xi64>, tensor<8xi64> -} - -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant -1767562579 : i32 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1879881855 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant -616729560 : i32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -239350328 : i32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 534103459 : i32 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1401181199 : i32 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1684936478 : i32 -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -1253254570 : i32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant -1459197799 : i32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 387276957 : i32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant -308364780 : i32 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 2027808484 : i32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 842468239 : i32 -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -626627285 : i32 -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1993301258 : i32 -// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 1013904242 : i32 -// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 3449720151 : i64 -// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 3528531795 : i64 -// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_20:.*]] = arith.constant -1150833019 : i32 -// CHECK-DAG: %[[VAL_21:.*]] = arith.constant -1640531527 : i32 -// CHECK-DAG: %[[VAL_22:.*]] = arith.constant 4 : i64 -// CHECK-DAG: %[[VAL_23:.*]] = arith.constant 32 : i64 -// CHECK-DAG: %[[VAL_24:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_19]]] : tensor<2xi64> -// CHECK-DAG: %[[VAL_26:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_24]]] : tensor<2xi64> -// CHECK-DAG: %[[VAL_27:.*]] = arith.trunci %[[VAL_26]] : i64 to i32 -// CHECK-DAG: %[[VAL_28:.*]] = arith.shrui %[[VAL_26]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_29:.*]] = arith.trunci %[[VAL_28]] : i64 to i32 -// CHECK-DAG: %[[VAL_30:.*]] = arith.addi %[[VAL_25]], %[[VAL_22]] : i64 -// CHECK-DAG: %[[VAL_31:.*]] = tensor.empty() : tensor<4xi64> -// CHECK-DAG: %[[VAL_32:.*]] = tensor.empty() : tensor<4xi64> -// CHECK-DAG: %[[VAL_33:.*]]:2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} outs(%[[VAL_31]], %[[VAL_32]] : tensor<4xi64>, tensor<4xi64>) { -// CHECK-DAG: ^bb0(%[[VAL_34:.*]]: i64, %[[VAL_35:.*]]: i64): -// CHECK-DAG: %[[VAL_36:.*]] = linalg.index 0 : index -// CHECK-DAG: %[[VAL_37:.*]] = arith.index_cast %[[VAL_36]] : index to i64 -// CHECK-DAG: %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_25]] : i64 -// CHECK-DAG: %[[VAL_39:.*]] = arith.trunci %[[VAL_38]] : i64 to i32 -// CHECK-DAG: %[[VAL_40:.*]] = arith.shrui %[[VAL_38]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_41:.*]] = arith.trunci %[[VAL_40]] : i64 to i32 - -// CHECK-DAG: %[[VAL_42:.*]] = arith.extui %[[VAL_39]] : i32 to i64 -// CHECK-DAG: %[[VAL_43:.*]] = arith.muli %[[VAL_42]], %[[VAL_17]] : i64 -// CHECK-DAG: %[[VAL_44:.*]] = arith.shrui %[[VAL_43]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_45:.*]] = arith.trunci %[[VAL_44]] : i64 to i32 -// CHECK-DAG: %[[VAL_46:.*]] = arith.trunci %[[VAL_43]] : i64 to i32 -// CHECK-DAG: %[[VAL_47:.*]] = arith.extui %[[VAL_27]] : i32 to i64 -// CHECK-DAG: %[[VAL_48:.*]] = arith.muli %[[VAL_47]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_49:.*]] = arith.shrui %[[VAL_48]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_50:.*]] = arith.trunci %[[VAL_49]] : i64 to i32 -// CHECK-DAG: %[[VAL_51:.*]] = arith.trunci %[[VAL_48]] : i64 to i32 -// CHECK-DAG: %[[VAL_52:.*]] = arith.xori %[[VAL_50]], %[[VAL_41]] : i32 -// CHECK-DAG: %[[VAL_53:.*]] = arith.xori %[[VAL_52]], %[[VAL_27]] : i32 - -// CHECK-DAG: %[[VAL_54:.*]] = arith.addi %[[VAL_27]], %[[VAL_21]] : i32 -// CHECK-DAG: %[[VAL_55:.*]] = arith.addi %[[VAL_29]], %[[VAL_20]] : i32 -// CHECK-DAG: %[[VAL_56:.*]] = arith.extui %[[VAL_53]] : i32 to i64 -// CHECK-DAG: %[[VAL_57:.*]] = arith.muli %[[VAL_56]], %[[VAL_17]] : i64 -// CHECK-DAG: %[[VAL_58:.*]] = arith.shrui %[[VAL_57]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_59:.*]] = arith.trunci %[[VAL_58]] : i64 to i32 -// CHECK-DAG: %[[VAL_60:.*]] = arith.trunci %[[VAL_57]] : i64 to i32 -// CHECK-DAG: %[[VAL_61:.*]] = arith.extui %[[VAL_45]] : i32 to i64 -// CHECK-DAG: %[[VAL_62:.*]] = arith.muli %[[VAL_61]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_63:.*]] = arith.shrui %[[VAL_62]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_64:.*]] = arith.trunci %[[VAL_63]] : i64 to i32 -// CHECK-DAG: %[[VAL_65:.*]] = arith.trunci %[[VAL_62]] : i64 to i32 -// CHECK-DAG: %[[VAL_66:.*]] = arith.xori %[[VAL_64]], %[[VAL_51]] : i32 -// CHECK-DAG: %[[VAL_67:.*]] = arith.xori %[[VAL_66]], %[[VAL_54]] : i32 -// CHECK-DAG: %[[VAL_68:.*]] = arith.xori %[[VAL_59]], %[[VAL_46]] : i32 -// CHECK-DAG: %[[VAL_69:.*]] = arith.xori %[[VAL_68]], %[[VAL_55]] : i32 - -// CHECK-DAG: %[[VAL_70:.*]] = arith.addi %[[VAL_27]], %[[VAL_16]] : i32 -// CHECK-DAG: %[[VAL_71:.*]] = arith.addi %[[VAL_29]], %[[VAL_15]] : i32 -// CHECK-DAG: %[[VAL_72:.*]] = arith.extui %[[VAL_67]] : i32 to i64 -// CHECK-DAG: %[[VAL_73:.*]] = arith.muli %[[VAL_72]], %[[VAL_17]] : i64 -// CHECK-DAG: %[[VAL_74:.*]] = arith.shrui %[[VAL_73]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_75:.*]] = arith.trunci %[[VAL_74]] : i64 to i32 -// CHECK-DAG: %[[VAL_76:.*]] = arith.trunci %[[VAL_73]] : i64 to i32 -// CHECK-DAG: %[[VAL_77:.*]] = arith.extui %[[VAL_69]] : i32 to i64 -// CHECK-DAG: %[[VAL_78:.*]] = arith.muli %[[VAL_77]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_79:.*]] = arith.shrui %[[VAL_78]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_80:.*]] = arith.trunci %[[VAL_79]] : i64 to i32 -// CHECK-DAG: %[[VAL_81:.*]] = arith.trunci %[[VAL_78]] : i64 to i32 -// CHECK: %[[VAL_82:.*]] = arith.xori %[[VAL_80]], %[[VAL_65]] : i32 -// CHECK-DAG: %[[VAL_83:.*]] = arith.xori %[[VAL_82]], %[[VAL_70]] : i32 -// CHECK-DAG: %[[VAL_84:.*]] = arith.xori %[[VAL_75]], %[[VAL_60]] : i32 -// CHECK-DAG: %[[VAL_85:.*]] = arith.xori %[[VAL_84]], %[[VAL_71]] : i32 - -// CHECK-DAG: %[[VAL_86:.*]] = arith.addi %[[VAL_27]], %[[VAL_14]] : i32 -// CHECK-DAG: %[[VAL_87:.*]] = arith.addi %[[VAL_29]], %[[VAL_13]] : i32 -// CHECK-DAG: %[[VAL_88:.*]] = arith.extui %[[VAL_83]] : i32 to i64 -// CHECK-DAG: %[[VAL_89:.*]] = arith.muli %[[VAL_88]], %[[VAL_17]] : i64 -// CHECK-DAG: %[[VAL_90:.*]] = arith.shrui %[[VAL_89]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_91:.*]] = arith.trunci %[[VAL_90]] : i64 to i32 -// CHECK-DAG: %[[VAL_92:.*]] = arith.trunci %[[VAL_89]] : i64 to i32 -// CHECK-DAG: %[[VAL_93:.*]] = arith.extui %[[VAL_85]] : i32 to i64 -// CHECK-DAG: %[[VAL_94:.*]] = arith.muli %[[VAL_93]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_95:.*]] = arith.shrui %[[VAL_94]], %[[VAL_23]] : i64 -// CHECK-DAG: %[[VAL_96:.*]] = arith.trunci %[[VAL_95]] : i64 to i32 -// CHECK-DAG: %[[VAL_97:.*]] = arith.trunci %[[VAL_94]] : i64 to i32 -// CHECK-DAG: %[[VAL_98:.*]] = arith.xori %[[VAL_96]], %[[VAL_81]] : i32 -// CHECK-DAG: %[[VAL_99:.*]] = arith.xori %[[VAL_98]], %[[VAL_86]] : i32 -// CHECK-DAG: %[[VAL_100:.*]] = arith.xori %[[VAL_91]], %[[VAL_76]] : i32 -// CHECK-DAG: %[[VAL_101:.*]] = arith.xori %[[VAL_100]], %[[VAL_87]] : i32 - -// CHECK: linalg.yield %[[YIELDED_1:.*]], %[[YIELDED_2:.*]] : i64, i64 -// CHECK-DAG: %[[VAL_206:.*]] = tensor.expand_shape %[[VAL_207:.*]]#0 {{\[\[}}0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64> -// CHECK-DAG: %[[VAL_208:.*]] = tensor.expand_shape %[[VAL_207]]#1 {{\[\[}}0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64> -// CHECK-DAG: %[[VAL_209:.*]] = tensor.empty() : tensor<4x2xi64> -// CHECK-DAG: %[[VAL_213:.*]] = tensor.insert %[[VAL_30]] into %[[VAL_0]]{{\[}}%[[VAL_19]]] : tensor<2xi64> - -// CHECK: return %[[VAL_213]], %[[GENERIC:.*]] : tensor<2xi64>, tensor<8xi64> - - - -// ----- - -// CHECK-LABEL: func.func @philox_i32 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @philox_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) - return %output_state, %output : tensor<2xi64>, tensor<8xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[GENERIC:.+]]:4 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map, #map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) - -// CHECK: %[[CONCAT:.+]] = linalg.generic - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] -// CHECK-SAME{literal}: [[0, 1]] : tensor<2x4xi32> into tensor<8xi32> -// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> - -// CHECK: return %[[INSERTED]], %[[COLLAPSE]] - - -// ----- - -// CHECK-LABEL: func.func @philox_128_i32 -// CHECK-SAME: %[[ARG0:.*]]: tensor<3xi64> -func.func @philox_128_i32(%arg0: tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>) - return %output_state, %output : tensor<3xi64>, tensor<8xi32> -} - -// ----- - -func.func @philox_i32_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) - return %output_state, %output : tensor<2xi64>, tensor<7x11xi32> -} - -// CHECK-LABEL: func.func @philox_i32_odd -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> - - //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C20]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<20xi32> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<20xi32> -// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<20xi32> -// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<20xi32> -// CHECK: %[[GENERIC:.+]]:4 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map, #map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<20xi32>, tensor<20xi32>, tensor<20xi32>, tensor<20xi32>) - - -// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32> - -// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1 -// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32> - - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<20x4xi32> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<20x4xi32>) - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] - - -// CHECK: %[[VAL_213:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1]] output_shape [80, 1] : tensor<80xi32> into tensor<80x1xi32> -// CHECK: %[[VAL_214:.*]] = tensor.extract_slice %[[VAL_213]][0, 0] [77, 1] [1, 1] : tensor<80x1xi32> to tensor<77x1xi32> -// CHECK: %[[VAL_215:.*]] = tensor.collapse_shape %[[VAL_214]] {{\[\[}}0, 1]] : tensor<77x1xi32> into tensor<77xi32> -// CHECK: %[[VAL_216:.*]] = tensor.expand_shape %[[VAL_215]] {{\[\[}}0, 1]] output_shape [7, 11] : tensor<77xi32> into tensor<7x11xi32> -// CHECK: %[[VAL_217:.*]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]]{{\[}}%[[C1]]] : tensor<2xi64> -// CHECK: return %[[VAL_217]], %[[VAL_216]] : tensor<2xi64>, tensor<7x11xi32> - - -// ----- - - -func.func @philox_i64_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>) - return %output_state, %output : tensor<2xi64>, tensor<3x5xi64> -} - -// CHECK-LABEL: func.func @philox_i64_odd -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> - - //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C8]] : i64 - -// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<8xi64> -// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<8xi64> -// CHECK: %[[GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST2]], %[[DEST3]] : tensor<8xi64>, tensor<8xi64>) - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xi64> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xi64>) - -// CHECK-DAG: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] {{\[\[}}0, 1]] : tensor<8x2xi64> into tensor<16xi64> - - -// CHECK-DAG: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1]] output_shape [16, 1] : tensor<16xi64> into tensor<16x1xi64> -// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0] [15, 1] [1, 1] : tensor<16x1xi64> to tensor<15x1xi64> -// CHECK-DAG: %[[EXPAND_2:.*]] = tensor.collapse_shape %[[SLICE]] {{\[\[}}0, 1]] : tensor<15x1xi64> into tensor<15xi64> -// CHECK-DAG: %[[RESHAPE:.*]] = tensor.expand_shape %[[EXPAND_2]] {{\[\[}}0, 1]] output_shape [3, 5] : tensor<15xi64> into tensor<3x5xi64> -// CHECK-DAG: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: return %[[INSERTED]], %[[RESHAPE]] - -// ----- - -func.func @philox_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) - return %output_state, %output : tensor<2xi64>, tensor<8xi16> -} - -// CHECK-LABEL: func.func @philox_i16 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> - - //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi16> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi16> -// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi16> -// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi16> -// CHECK: %[[GENERIC:.+]]:4 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map, #map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi16>, tensor<2xi16>, tensor<2xi16>, tensor<2xi16>) - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi16> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4xi16>) - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] -// CHECK-SAME{literal}: [[0, 1]] : tensor<2x4xi16> into tensor<8xi16> -// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64> - -// CHECK: return %[[INSERTED]], %[[COLLAPSE]] - -// ----- - -func.func @philox_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) - return %output_state, %output : tensor<2xi64>, tensor<8xi8> -} - -// CHECK-LABEL: func.func @philox_i8 -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> - - //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64 - -// Check we update state correctly: -// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64> -// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64 - -// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi8> -// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi8> -// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi8> -// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi8> -// CHECK: %[[GENERIC:.+]]:4 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map, #map, #map] -// CHECK-SAME: iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi8>, tensor<2xi8>, tensor<2xi8>, tensor<2xi8>) - -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi8> -// CHECK: %[[CONCAT:.+]] = linalg.generic -// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4xi8>) diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir deleted file mode 100644 index 313f645bcb1d..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir +++ /dev/null @@ -1,794 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// RUN: iree-opt %s --iree-stablehlo-to-linalg="enable-primitive-ops=true" \ -// RUN: --split-input-file --canonicalize | \ -// RUN: FileCheck %s --check-prefix=CHECK-PRIMITIVE - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: @reduce_add -// CHECK-PRIMITIVE-LABEL: @reduce_add -func.func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> { - %0 = "stablehlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4 : tensor): - %1 = stablehlo.add %arg3, %arg4 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array, someattr} : (tensor<5x4xi32>, tensor) -> tensor<5xi32> - func.return %0 : tensor<5xi32> -} -// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() -// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) -// CHECK-SAME: {someattr} -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() -// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK-PRIMITIVE: linalg.reduce { arith.addi {overflowFlags = #arith.overflow} } -// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) -// CHECK-PRIMITIVE-SAME: dimensions = [1] {someattr} - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: @reduce_dim0 -// CHECK-PRIMITIVE-LABEL: @reduce_dim0 -func.func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<4xi32> { - %0 = "stablehlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4 : tensor): - %1 = stablehlo.maximum %arg3, %arg4 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} : (tensor<5x4xi32>, tensor) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} -// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() -// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>) -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.maxsi %[[RHS_IN]], %[[LHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() -// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK-PRIMITIVE: linalg.reduce { arith.maxsi } -// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>) -// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>) -// CHECK-PRIMITIVE-SAME: dimensions = [0] - -// ----- - -func.func @reduce_dynamic_output(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor { - %0 = "stablehlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4 : tensor): - %1 = stablehlo.maximum %arg3, %arg4 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} : (tensor<5x4xi32>, tensor) -> tensor - func.return %0 : tensor -} - -// Regression test: just check that this lowers successfully. -// CHECK-LABEL: @reduce_dynamic_output -// CHECK: linalg.generic - -// CHECK-PRIMITIVE-LABEL: @reduce_dynamic_output -// CHECK-PRIMITIVE: linalg.reduce - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: @reduce_init_const -func.func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> { - %cst = arith.constant dense<0xFF800000> : tensor - %0 = "stablehlo.reduce"(%arg0, %cst) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %1 = stablehlo.add %arg1, %arg2 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} : (tensor<1x10xf32>, tensor) -> tensor<1xf32> - func.return %0 : tensor<1xf32> -} -// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() -// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT_TENSOR]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>) -// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<1xf32>) -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.addf %[[RHS_IN]], %[[LHS_IN]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK: @reduce_multi_dimensions -func.func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>, - %arg1: tensor) -> tensor<4xi32> { - %0 = "stablehlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} : (tensor<5x4x3xi32>, tensor) -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} -// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() -// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"] -// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>) -// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>) -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// ----- - -// CHECK-LABEL: @reduce_lexicographic_min_complex -// CHECK-PRIMITIVE-LABEL: @reduce_lexicographic_min_complex -func.func @reduce_lexicographic_min_complex(%arg0: tensor>, - %arg1: tensor>) - -> tensor> { - %0 = stablehlo.reduce(%arg0 init: %arg1) - across dimensions = [0, 1, 2] - : (tensor>, tensor>) -> tensor> - reducer(%arg3: tensor>, %arg4: tensor>) { - %1 = stablehlo.real %arg3 : (tensor>) -> tensor - %2 = stablehlo.convert %arg4 : (tensor>) -> tensor - %3 = "stablehlo.compare"(%1, %2) - {comparison_direction = #stablehlo} - : (tensor, tensor) -> tensor - %4 = stablehlo.imag %arg3 : (tensor>) -> tensor - %5 = stablehlo.imag %arg4 : (tensor>) -> tensor - %6 = "stablehlo.compare"(%4, %5) - {comparison_direction = #stablehlo} - : (tensor, tensor) -> tensor - %7 = "stablehlo.compare"(%1, %2) - {comparison_direction = #stablehlo} - : (tensor, tensor) -> tensor - %8 = "stablehlo.select"(%3, %6, %7) - : (tensor, tensor, tensor) -> tensor - %9 = "stablehlo.select"(%8, %arg3, %arg4) - : (tensor, tensor>, tensor>) - -> tensor> - "stablehlo.return"(%9) : (tensor>) -> () - } - return %0 : tensor> -} - -// CHECK: linalg.generic -// CHECK: complex.re -// CHECK: complex.re -// CHECK: arith.cmpf -// CHECK: complex.im -// CHECK: complex.im -// CHECK: arith.cmpf -// CHECK: arith.cmpf -// CHECK: arith.select - -// CHECK-PRIMITIVE: linalg.reduce -// CHECK-PRIMITIVE: complex.re -// CHECK-PRIMITIVE: complex.re -// CHECK-PRIMITIVE: arith.cmpf -// CHECK-PRIMITIVE: complex.im -// CHECK-PRIMITIVE: complex.im -// CHECK-PRIMITIVE: arith.cmpf -// CHECK-PRIMITIVE: arith.cmpf -// CHECK-PRIMITIVE: arith.select - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor -func.func @reduce_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "stablehlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4 : tensor): - %1 = stablehlo.add %arg3, %arg4 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = array} : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty(%[[DIM1]]) -// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%{{.*}}tensor) -// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor) -// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: func @variadic_reduce -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-PRIMITIVE-LABEL: func @variadic_reduce -// CHECK-PRIMITIVE-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-PRIMITIVE-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) { - %cst0 = stablehlo.constant dense<-2147483648> : tensor - %cst1 = stablehlo.constant dense<0> : tensor - %res0, %res1 = "stablehlo.reduce"(%arg0, %arg1, %cst0, %cst1) ({ - ^bb0(%arg2: tensor, %arg3: tensor, %arg15: tensor, %arg16: tensor): - %669 = "stablehlo.compare"(%arg2, %arg15) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor - %670 = "stablehlo.select"(%669, %arg2, %arg15) : (tensor, tensor, tensor) -> tensor - %671 = "stablehlo.compare"(%arg2, %arg15) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor - %672 = stablehlo.minimum %arg3, %arg16 : tensor - %673 = "stablehlo.select"(%669, %arg3, %arg16) : (tensor, tensor, tensor) -> tensor - %674 = "stablehlo.select"(%671, %672, %673) : (tensor, tensor, tensor) -> tensor - "stablehlo.return"(%670, %674) : (tensor, tensor) -> () - }) {dimensions = array} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor, tensor) -> (tensor<2xi32>, tensor<2xi32>) - func.return %res0, %res1 : tensor<2xi32>, tensor<2xi32> -} -// CHECK-DAG: %[[CST0:.*]] = arith.constant -2147483648 : i32 -// CHECK-DAG: %[[CST1:.*]] = arith.constant 0 : i32 -// CHECK: %[[INIT0:.*]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]] -// CHECK: %[[INIT1:.*]] = tensor.empty() : tensor<2xi32> -// CHECK: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]] -// CHECK: %[[RES:.+]]:2 = linalg.generic { -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>) -// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>) -// CHECK-NEXT: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32): -// CHECK-NEXT: %[[T1:.*]] = arith.cmpi sge, %[[OUT0]], %[[IN0]] : i32 -// CHECK-NEXT: %[[T2:.*]] = arith.select %[[T1]], %[[OUT0]], %[[IN0]] : i32 -// CHECK-NEXT: %[[T3:.*]] = arith.cmpi eq, %[[OUT0]], %[[IN0]] : i32 -// CHECK-NEXT: %[[T4:.*]] = arith.minsi %[[OUT1:.*]], %[[IN1]] : i32 -// CHECK-NEXT: %[[T5:.*]] = arith.select %[[T1]], %[[OUT1]], %[[IN1]] : i32 -// CHECK-NEXT: %[[T6:.*]] = arith.select %[[T3]], %[[T4]], %[[T5]] : i32 -// CHECK-NEXT: linalg.yield %[[T2]], %[[T6]] - -// CHECK-PRIMITIVE-DAG: %[[CST0:.*]] = arith.constant -2147483648 : i32 -// CHECK-PRIMITIVE-DAG: %[[CST1:.*]] = arith.constant 0 : i32 -// CHECK-PRIMITIVE: %[[INIT0:.*]] = tensor.empty() : tensor<2xi32> -// CHECK-PRIMITIVE: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]] -// CHECK-PRIMITIVE: %[[INIT1:.*]] = tensor.empty() : tensor<2xi32> -// CHECK-PRIMITIVE: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]] -// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.reduce -// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>) -// CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>) -// CHECK-PRIMITIVE-SAME: dimensions = [0] -// CHECK-PRIMITIVE-NEXT: (%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32) { -// CHECK-PRIMITIVE-NEXT: %[[T1:.*]] = arith.cmpi sge, %[[OUT0]], %[[IN0]] : i32 -// CHECK-PRIMITIVE-NEXT: %[[T2:.*]] = arith.select %[[T1]], %[[OUT0]], %[[IN0]] : i32 -// CHECK-PRIMITIVE-NEXT: %[[T3:.*]] = arith.cmpi eq, %[[OUT0]], %[[IN0]] : i32 -// CHECK-PRIMITIVE-NEXT: %[[T4:.*]] = arith.minsi %[[OUT1:.*]], %[[IN1]] : i32 -// CHECK-PRIMITIVE-NEXT: %[[T5:.*]] = arith.select %[[T1]], %[[OUT1]], %[[IN1]] : i32 -// CHECK-PRIMITIVE-NEXT: %[[T6:.*]] = arith.select %[[T3]], %[[T4]], %[[T5]] : i32 -// CHECK-PRIMITIVE-NEXT: linalg.yield %[[T2]], %[[T6]] - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> -// CHECK: func @variadic_diff_type_reduce -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-PRIMITIVE-LABEL: func @variadic_diff_type_reduce -// CHECK-PRIMITIVE-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-PRIMITIVE-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @variadic_diff_type_reduce(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128xf32>, tensor<128xi32>) { - %cst0 = stablehlo.constant dense<1.0> : tensor - %cst1 = stablehlo.constant dense<1> : tensor - %res0, %res1 = "stablehlo.reduce"(%arg0, %arg1, %cst0, %cst1) ({ - ^bb0(%arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor): - %0 = "stablehlo.compare"(%arg7, %arg9) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor - %1 = "stablehlo.select"(%0, %arg7, %arg9) : (tensor, tensor, tensor) -> tensor - %2 = "stablehlo.select"(%0, %arg8, %arg10) : (tensor, tensor, tensor) -> tensor - "stablehlo.return"(%1, %2) : (tensor, tensor) -> () - }) {dimensions = array} : (tensor<128x10xf32>, tensor<128x10xi32>, tensor, tensor) ->(tensor<128xf32>, tensor<128xi32>) - func.return %res0, %res1 : tensor<128xf32>, tensor<128xi32> -} -// CHECK-DAG: %[[CST0:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32 -// CHECK: %[[INIT0:.*]] = tensor.empty() : tensor<128xf32> -// CHECK: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]] -// CHECK: %[[INIT1:.*]] = tensor.empty() : tensor<128xi32> -// CHECK: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]] -// CHECK: %[[RES:.+]]:2 = linalg.generic { -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<128x10xf32>, tensor<128x10xi32>) -// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<128xf32>, tensor<128xi32>) -// CHECK-NEXT: ^bb0(%[[LHS0:.*]]: f32, %[[LHS1:.*]]: i32, %[[RHS0:.*]]: f32, %[[RHS1:.*]]: i32): -// CHECK-NEXT: %[[B0:.*]] = arith.cmpf oge, %[[RHS0]], %[[LHS0]] : f32 -// CHECK-NEXT: %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32 -// CHECK-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32 -// CHECK-NEXT: linalg.yield %[[RES0]], %[[RES1]] : f32, i32 - -// CHECK-PRIMITIVE-DAG: %[[CST0:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-PRIMITIVE-DAG: %[[CST1:.*]] = arith.constant 1 : i32 -// CHECK-PRIMITIVE: %[[INIT0:.*]] = tensor.empty() : tensor<128xf32> -// CHECK-PRIMITIVE: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]] -// CHECK-PRIMITIVE: %[[INIT1:.*]] = tensor.empty() : tensor<128xi32> -// CHECK-PRIMITIVE: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]] -// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.reduce -// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<128x10xf32>, tensor<128x10xi32>) -// CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<128xf32>, tensor<128xi32>) -// CHECK-PRIMITIVE-SAME: dimensions = [1] -// CHECK-PRIMITIVE-NEXT: (%[[LHS0:.*]]: f32, %[[LHS1:.*]]: i32, %[[RHS0:.*]]: f32, %[[RHS1:.*]]: i32) { -// CHECK-PRIMITIVE-NEXT: %[[B0:.*]] = arith.cmpf oge, %[[RHS0]], %[[LHS0]] : f32 -// CHECK-PRIMITIVE-NEXT: %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32 -// CHECK-PRIMITIVE-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32 -// CHECK-PRIMITIVE-NEXT: linalg.yield %[[RES0]], %[[RES1]] : f32, i32 - -// ----- - -// Make sure we do not crash on unsupported reductions. - -// CHECK-LABEL: func.func @reduce_noop -// CHECK: stablehlo.reduce -// CHECK-PRIMITIVE-LABEL: func.func @reduce_noop -// CHECK-PRIMITIVE: stablehlo.reduce -func.func @reduce_noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - %0 = stablehlo.constant dense<0.000000e+00> : tensor - %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> - reducer(%arg1: tensor, %arg2: tensor) { - %4 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %4 : tensor - } - func.return %1 : tensor<4x8xf32> -} - -// CHECK-LABEL: func.func @reduce_zero_ext -// CHECK: stablehlo.reduce -// CHECK-PRIMITIVE-LABEL: func.func @reduce_zero_ext -// CHECK-PRIMITIVE: stablehlo.reduce -func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor { - %0 = stablehlo.constant dense : tensor - %1 = stablehlo.constant dense : tensor<0xi1> - %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> - %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32> - %4 = stablehlo.constant dense<0> : tensor - %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor) -> tensor - reducer(%arg1: tensor, %arg2: tensor) { - %6 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %6 : tensor - } - return %5 : tensor -} - -// ----- - -// CHECK-LABEL: func @reduce_window_min_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_min_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.minimum %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array, - someattr} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %0 : tensor<1x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_min -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: someattr, -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_max_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_max_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.maximum %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %0 : tensor<1x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_sum_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_sum_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %0 : tensor<1x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_max_nhwc_with_cst -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -func.func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x17x17x64xf32>) -> tensor<1x8x8x64xf32> { - %0 = arith.constant dense<0xFF800000> : tensor - %1 = "stablehlo.reduce_window"(%arg0, %0) ({ - ^bb0(%arg1: tensor, %arg2 : tensor): - %2 = stablehlo.maximum %arg1, %arg2 : tensor - "stablehlo.return"(%2) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %1 : tensor<1x8x8x64xf32> -} - -// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32 -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_sum_max_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_sum_max_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) { - %0:2 = "stablehlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor, %arg4: tensor, %arg5 : tensor): - %1 = stablehlo.add %arg2, %arg4 : tensor - %2 = stablehlo.maximum %arg3, %arg5 : tensor - "stablehlo.return"(%1, %2) : (tensor, tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor<1x17x17x64xf32>, tensor, tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) - func.return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32> -} - -// CHECK: %[[WINDOW0:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT0:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL0:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES0:.+]] = linalg.pooling_nhwc_sum -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[WINDOW1:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT1:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL1:.+]] = linalg.fill ins(%[[INIT_VAL1]] : f32) outs(%[[INIT1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES1:.+]] = linalg.pooling_nhwc_max -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: return %[[RES0]], %[[RES1]] - -// ----- - -// Just check that this lowers successfully. -// CHECK-LABEL: func @reduce_window_unsigned -func.func @reduce_window_unsigned(%arg0: tensor<1x1xui32>) -> tensor<1x1xui32> { - %0 = stablehlo.constant dense<0> : tensor - %1 = "stablehlo.reduce_window"(%arg0, %0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - stablehlo.return %arg1 : tensor - }) { - window_dimensions = array, - window_strides = array - } : (tensor<1x1xui32>, tensor) -> tensor<1x1xui32> - return %1 : tensor<1x1xui32> -} - -// ----- - -// CHECK-LABEL: func @dynamic_reduce_window_sum_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @dynamic_reduce_window_sum_nhwc(%arg0: tensor, - %arg1: tensor) -> tensor{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor -// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[C3]] -// CHECK: %[[T3:.+]] = arith.divui %[[T2]], %[[C2]] -// CHECK: %[[D1:.+]] = arith.addi %[[T3]], %[[C1]] -// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor -// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[C3]] -// CHECK: %[[T3:.+]] = arith.divui %[[T2]], %[[C2]] -// CHECK: %[[D2:.+]] = arith.addi %[[T3]], %[[C1]] -// CHECK: %[[D3:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) : tensor -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor) -> tensor -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @reduce_window_min_ndhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_min_ndhwc(%arg0: tensor<1x17x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.minimum %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x17x64xf32>, tensor) -> tensor<1x8x8x8x64xf32> - func.return %0 : tensor<1x8x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_min -// CHECK-SAME: {dilations = dense<1> : vector<3xi64> -// CHECK-SAME: strides = dense<2> : vector<3xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_max_ndhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_max_ndhwc(%arg0: tensor<1x17x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.maximum %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x17x64xf32>, tensor) -> tensor<1x8x8x8x64xf32> - func.return %0 : tensor<1x8x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_max -// CHECK-SAME: {dilations = dense<1> : vector<3xi64> -// CHECK-SAME: strides = dense<2> : vector<3xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_sum_ndhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_sum_ndhwc(%arg0: tensor<1x17x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x17x64xf32>, tensor) -> tensor<1x8x8x8x64xf32> - func.return %0 : tensor<1x8x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_sum -// CHECK-SAME: {dilations = dense<1> : vector<3xi64> -// CHECK-SAME: strides = dense<2> : vector<3xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_sum_ndhwc_dilated_base -// CHECK: linalg.generic -func.func @reduce_window_sum_ndhwc_dilated_base( - %arg0: tensor<1x17x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x16x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, - window_dimensions = array, - window_strides = array} : (tensor<1x17x17x17x64xf32>, tensor) -> tensor<1x8x8x16x64xf32> - func.return %0 : tensor<1x8x8x16x64xf32> -} - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 * 2, d1 + d2 * 2)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK: func @reduce_window_generic -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic(%arg0: tensor<4x6xf32>, %arg1: tensor) -> tensor<4x7xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x6xf32>, tensor) -> tensor<4x7xf32> - func.return %0 : tensor<4x7xf32> -} -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7xf32> -// CHECK: %[[FILL:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor) outs(%[[INIT]] : tensor<4x7xf32>) -// CHECK: ^{{[a-z0-9_]*}} -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32 -// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32 -// CHECK: linalg.yield %[[IN]] : f32 - -// CHECK: %[[PADVAL:.+]] = tensor.extract %arg1[] : tensor -// CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1] high[3, 2] -// CHECK: ^{{[a-z0-9_]*}} -// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index -// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index -// CHECK: tensor.yield %[[PADVAL]] : f32 - -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<2xf32> -// CHECK: %[[REDUCE:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[PAD]], %[[WINDOW]] : tensor<7x9xf32>, tensor<2xf32>) outs(%[[FILL]] : tensor<4x7xf32>) { -// CHECK: ^{{[a-z0-9_]*}} -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32 -// CHECK-SAME: %[[IN2:[a-zA-Z0-9_]*]]: f32 -// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32 -// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[IN]] : f32 -// CHECK: linalg.yield %[[ADD]] - -// CHECK: return %[[REDUCE]] -// ----- - -// CHECK-LABEL: func @reduce_window_generic_captured_constant -func.func @reduce_window_generic_captured_constant(%arg0: tensor<4x6xf32>, %arg1: tensor) -> tensor<4x7xf32> { - %c2 = stablehlo.constant dense<2.0> : tensor - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - %2 = stablehlo.multiply %1, %c2 : tensor - "stablehlo.return"(%2) : (tensor) -> () - }) {base_dilations = array, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x6xf32>, tensor) -> tensor<4x7xf32> - func.return %0 : tensor<4x7xf32> -} - -// CHECK: %[[C2:.*]] = arith.constant 2.0 -// CHECK: linalg.generic -// CHECK: %[[SUM:.*]] = arith.addf -// CHECK: %[[PROD:.*]] = arith.mulf %[[SUM]], %[[C2]] -// CHECK: linalg.yield %[[PROD]] - -// ----- - -// CHECK-LABEL: func @reduce_window_generic_padding -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic_padding(%arg0: tensor<3x6xf32>, %arg1: tensor) -> tensor<3x7xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<3x6xf32>, tensor) -> tensor<3x7xf32> - func.return %0 : tensor<3x7xf32> -} -// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[0, 1] high[3, 2] -// CHECK: tensor.yield %[[PADVAL]] : f32 - -// ----- - -// CHECK-LABEL: func @reduce_window_generic_base_dilation -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor) -> tensor<3x4xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<3x6xf32>, tensor) -> tensor<3x4xf32> - func.return %0 : tensor<3x4xf32> -} -// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<5x6xf32>) -> tensor<5x6xf32> -// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 0] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<5x6xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_generic_padding_base_dilation -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic_padding_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor) -> tensor<4x7xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<3x6xf32>, tensor) -> tensor<4x7xf32> - func.return %0 : tensor<4x7xf32> -} -// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<8x9xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<8x9xf32>) -> tensor<8x9xf32> -// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 1] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<8x9xf32> - -// ----- - -// CHECK: #[[MAP:.+]] = affine_map<() -> ()> -// CHECK: func @reduce_window_generic_scalar -func.func @reduce_window_generic_scalar(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, padding = dense<> : tensor<0x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]