diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index 0c0469eb335c..0d284e9c6205 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -90,6 +90,9 @@ def __init__(self, repo_map: Dict[str, str]): "@stablehlo//:stablehlo_passes": [ "StablehloPasses", ], + "@stablehlo//:linalg_passes": [ + "StablehloLinalgTransforms", + ], "@stablehlo//:vhlo_ops": [ "VhloOps", ], diff --git a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel index 110c27ae4eac..768cc48b2d83 100644 --- a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel @@ -70,17 +70,8 @@ iree_compiler_cc_library( "LegalizeToLinalgUtils.h", "MapStableHLOToScalarOp.h", "StableHLOCustomCalls.cpp", - "StableHLOToArith.cpp", "StableHLOToIREEInputDialects.cpp", - "StableHLOToLinalg.cpp", - "StableHLOToLinalgConvolution.cpp", - "StableHLOToLinalgDotProd.cpp", "StableHLOToLinalgExt.cpp", - "StableHLOToLinalgPointwise.cpp", - "StableHLOToLinalgRandom.cpp", - "StableHLOToLinalgReduce.cpp", - "TypeConversion.cpp", - "TypeConversion.h", "VerifyCompilerInputLegality.cpp", ], deps = [ @@ -121,6 +112,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:VectorDialect", "@stablehlo//:broadcast_utils", "@stablehlo//:chlo_ops", + "@stablehlo//:linalg_passes", "@stablehlo//:stablehlo_ops", "@stablehlo//:vhlo_ops", ], diff --git a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt index 5afe3a28fdf4..e1e76357b59b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt +++ b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt @@ -56,17 +56,8 @@ iree_cc_library( "LegalizeToLinalgUtils.h" "MapStableHLOToScalarOp.h" "StableHLOCustomCalls.cpp" - "StableHLOToArith.cpp" "StableHLOToIREEInputDialects.cpp" - "StableHLOToLinalg.cpp" - "StableHLOToLinalgConvolution.cpp" - "StableHLOToLinalgDotProd.cpp" "StableHLOToLinalgExt.cpp" - "StableHLOToLinalgPointwise.cpp" - "StableHLOToLinalgRandom.cpp" - "StableHLOToLinalgReduce.cpp" - "TypeConversion.cpp" - "TypeConversion.h" "VerifyCompilerInputLegality.cpp" DEPS ::CHLODecompositionPatterns @@ -99,6 +90,7 @@ iree_cc_library( MLIRTransforms MLIRVectorDialect StablehloBroadcastUtils + StablehloLinalgTransforms StablehloOps VhloOps iree::compiler::Dialect::Flow::IR diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp index 54bf385bb7a5..4520409dee30 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp @@ -13,21 +13,10 @@ #include #include -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" -#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" namespace mlir::iree_compiler::stablehlo { -namespace { -bool hasIntegralShapeType(Operation *op) { - auto stp = llvm::dyn_cast(op->getOperand(0).getType()); - return stp && stp.getElementType().isIntOrIndex(); -} - -} // namespace SmallVector getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) { @@ -42,124 +31,6 @@ 
getNParallelLoopsAttrs(unsigned nParallelLoops) { return getParallelAndReductionIterators(nParallelLoops, 0); } -Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes) { - return b.create( - loc, llvm::cast(type), dynSizes, - /*copy=*/Value(), - /*memory_space=*/IntegerAttr()); -} - -Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes) { - return b.create( - loc, type.getShape(), type.getElementType(), dynSizes, - llvm::cast(type).getEncoding()); -} - -Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType, - Operation *op, ValueRange operands) { - bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr; - // Collect the sizes for a ranked tensor to be passed as parameter to a - // new tensor initialization operation. This operation only needs the - // dynamic sizes. - SmallVector sizes; - if (resultType.hasRank() && !resultType.hasStaticShape()) { - // Ask the op for its output shape. - auto shapeSource = cast(op); - SmallVector reifiedShapes; - (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes); - assert(reifiedShapes.size() == 1 && "Expected one reified result"); - // Construct sizes for the required dimensions. - for (const auto &en : llvm::enumerate(resultType.getShape())) { - if (!ShapedType::isDynamic(en.value())) - continue; - sizes.push_back(b.create( - loc, reifiedShapes[0], - ValueRange{b.create(loc, en.index())})); - } - } - return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes) - : getEmptyTensor(b, loc, resultType, sizes); -} - -Value coerceTensorShape(OpBuilder &builder, Location loc, - TypedValue value, ShapedType targetType) { - return builder.createOrFold( - loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()), - value); -} - -LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op) { - auto isRankedTensor = [](Value val) { - return isa(val.getType()); - }; - if (!llvm::all_of(op->getOperands(), isRankedTensor)) - return failure(); - return success(llvm::all_of(op->getResults(), isRankedTensor)); -} - -Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor) { - auto type = cast(tensor.getType()); - Value zero; - // Complex numbers are a special case. - if (auto complexType = llvm::dyn_cast(type.getElementType())) { - auto zeroElement = builder.getZeroAttr(complexType.getElementType()); - auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement}); - zero = builder.create(loc, complexType, zeroAttr); - } else { - auto zeroAttr = builder.getZeroAttr(type.getElementType()); - zero = builder.create(loc, zeroAttr); - } - return builder.create(loc, zero, tensor).result(); -} - -Value preSparsify(Operation *op, llvm::SmallVector &values, Type rtp, - OpBuilder *b) { - // Apply for semi-ring operations that lower to elaborate code - // (any sign-op, or an integral abs-op). 
- // TODO(peiming, ajcbik): these all can potentially be optimized by applying - // value transform on sparse_tenosr.value memref - if (isa(op) || - (isa(op) && hasIntegralShapeType(op)) || - isa(op)) { - if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) && - !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType())) - return Value(); - Location loc = op->getLoc(); - auto semiring = b->create(loc, rtp, values[0]); - Type itp = values[0].getType(); - Block *present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc); - b->setInsertionPointToStart(&semiring.getPresentRegion().front()); - values[0] = present->getArgument(0); - return semiring; - } - return Value(); -} - -Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b) { - if (semiring) { - b->create(op->getLoc(), result); - b->setInsertionPointAfter(semiring.getDefiningOp()); - return semiring; - } - return result; -} - -bool allOperandsAreScalarTensors(Operation *op) { - return llvm::all_of(op->getOperands(), [](Value operand) { - auto operandTy = llvm::dyn_cast(operand.getType()); - return operandTy && operandTy.getRank() == 0; - }); -} - -bool isInBodyOfLinalgOps(Operation *op) { - auto *parentOp = op->getParentRegion()->getParentOp(); - return parentOp->getDialect() == - parentOp->getContext()->getLoadedDialect(); -} - SmallVector extract1DVector(DenseIntElementsAttr elements) { SmallVector ret; for (const APInt &element : elements) { diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h index 2b74488c9ef3..5bbe07fdb4fe 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h @@ -15,26 +15,16 @@ #include #include -#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" namespace mlir::iree_compiler::stablehlo { @@ -49,75 +39,9 @@ getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction); SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops); -/// Generates an init sparse tensor. -Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes); - -/// Generates a tensor.empty op. -Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type, - ArrayRef dynSizes); - -/// Generates an empty tensor for the result of the operation, which could be a -/// dense tensor or a sparse tensor. -Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType, - Operation *op, ValueRange operands); - -/// Ensures a tensor has the same shape (not including the element type) as -/// another. 
-Value coerceTensorShape(OpBuilder &builder, Location loc, - TypedValue value, ShapedType targetType); - -/// Verifies |op|'s semantics by checking if all operands and results have -/// ranged tensor types. -LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op); - -/// Fills |tensor| with a zero constant of the matching type. Returns the new -/// value. -Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor); - -/// Sparsifies a (block of) operation(s) that cannot be handled directly -/// by the sparse compiler but has well-known semi-ring semantics. -/// -/// This yields something of the following form: -/// -/// %result = sparse_tensor.unary %values[0] -/// present={ -/// ^bb1(%val): -/// ... codegen proceeds here using %val .... -/// sparse_tensor.yield -/// } -/// absent={} -/// linalg.yield %result -Value preSparsify(Operation *op, llvm::SmallVector &values, Type rtp, - OpBuilder *b); - -/// Finalizes sparse semi-ring construction. -Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b); - -/// Returns true if all operands are tensors with rank 0. -bool allOperandsAreScalarTensors(Operation *op); - -/// Returns true if parent op is linalg. -bool isInBodyOfLinalgOps(Operation *op); - /// Extracts integer values from the attribute |elements|. SmallVector extract1DVector(DenseIntElementsAttr elements); -/// Returns true if the given |values| is a splat of the given |queryValue|. -inline bool isSplatValue(const ArrayRef &values, int64_t queryValue) { - for (auto value : values) { - if (value != queryValue) { - return false; - } - } - return true; -} - -/// Returns true if the given |attr| is a splat of the given |value|. -inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) { - return attr.isSplat() && attr.getSplatValue() == value; -} - } // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_ diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp index d8a0e66f3cdd..b66efc1f2083 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp @@ -87,7 +87,6 @@ void buildStableHLOInputConversionPassPipelineImpl( stablehlo::createConvertStableHloToLinalgExt()); passManager.addNestedPass(stablehlo::createLegalizeChlo()); passManager.addPass(createConvertStableHloToIreeInputDialects()); - // Ensure conversion completed. 
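// A condensed sketch of the pass ordering this hunk leaves behind. The
// function name and the func::FuncOp nesting argument are assumptions for
// illustration; the hunk itself does not show the addNestedPass template
// arguments.
static void buildInputConversionSketch(OpPassManager &pm) {
  pm.addNestedPass<func::FuncOp>(
      stablehlo::createConvertStableHloToLinalgExt());
  pm.addNestedPass<func::FuncOp>(stablehlo::createLegalizeChlo());
  // The StableHLO->Linalg patterns now run inside this pass instead of a
  // separate iree-stablehlo-to-linalg pass.
  pm.addPass(createConvertStableHloToIreeInputDialects());
  pm.addPass(createReconcileUnrealizedCastsPass());
}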
passManager.addPass(createReconcileUnrealizedCastsPass()); // Note that some StableHLO ops are left by the above and must resolve via diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.h b/compiler/plugins/input/StableHLO/Conversion/Passes.h index 4b63daeb6905..e6bd1741f4be 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Passes.h +++ b/compiler/plugins/input/StableHLO/Conversion/Passes.h @@ -10,11 +10,7 @@ #include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h" #include "mlir/Pass/Pass.h" -namespace mlir { -class TypeConverter; -namespace iree_compiler::stablehlo { - -std::unique_ptr createStableHloToLinalgTypeConverter(); +namespace mlir::iree_compiler::stablehlo { struct StableHloOptions : public PassPipelineOptions {}; @@ -36,7 +32,6 @@ void buildStableHLOXLAInputConversionPassPipeline( void registerStableHLOConversionPasses(); -} // namespace iree_compiler::stablehlo -} // namespace mlir +} // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_PASSES_H_ diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.td b/compiler/plugins/input/StableHLO/Conversion/Passes.td index 8f7aa2b1ec3d..7852112d453b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Passes.td +++ b/compiler/plugins/input/StableHLO/Conversion/Passes.td @@ -32,15 +32,6 @@ def ConvertStableHloToLinalgExt : // General passes //===----------------------------------------------------------------------===// -def ConvertStableHloToLinalg : - Pass<"iree-stablehlo-to-linalg", "ModuleOp"> { - let summary = "Converts from StableHLO ops to Linalg ops on"; - let options = [Option<"enablePrimitiveOps", "enable-primitive-ops", "bool", - /*default=*/"false", - "Lower to primitive Linalg ops (map, reduce and " - "transpose) when possible, instead of linalg.generic">]; -} - def LegalizeControlFlow : InterfacePass<"iree-stablehlo-legalize-control-flow", "mlir::FunctionOpInterface"> { let summary = "Legalizes from StableHLO control flow to SCF control flow"; diff --git a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h index 8a4380284f52..1a328ff3908b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h +++ b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h @@ -15,12 +15,6 @@ namespace mlir::iree_compiler::stablehlo { // General StableHLO/CHLO lowering patterns. //===----------------------------------------------------------------------===// -/// Populates the patterns that convert from StableHLO to Linalg on tensors. -void populateStableHloToLinalgConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet *patterns, - bool enablePrimitiveOps); - /// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and /// Shape ops. void populateLegalizeChloPatterns(MLIRContext *context, @@ -44,59 +38,12 @@ void populateStableHloToLinalgExtConversionPatterns( MLIRContext *context, TypeConverter &typeConverter, RewritePatternSet *patterns); -/// Populates the patterns that convert from StableHLO to Linalg on tensors. -/// Extends the general linalg lowering patterns with IREE-specific ones. -void populateStableHloToLinalgOnTensorsConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns); - /// Populates the patterns that convert from StableHLO collective ops to Flow /// ops. 
 void populateStableHloCollectivesConversionPatterns(
     MLIRContext *context, TypeConverter &typeConverter,
     RewritePatternSet *patterns);
 
-//===----------------------------------------------------------------------===//
-// Fine-grained patterns used by the implementation.
-//===----------------------------------------------------------------------===//
-namespace detail {
-/// Populates the patterns that convert from elementwise StableHLO ops to Linalg
-/// on tensors.
-void populatePointwiseStableHloToLinalgConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns, bool enablePrimitiveOps);
-
-/// Populates the patterns that convert from convolution StableHLO ops to Linalg
-/// on tensors.
-void populateStableHloConvolutionToLinalgConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns);
-
-/// Populates the patterns that convert from dot product StableHLO ops to Linalg
-/// on tensors.
-void populateStableHloDotProdToLinalgConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns);
-
-/// Populates the patterns that convert from random number generation StableHLO
-/// ops to Linalg on tensors.
-void populateStableHloRandomToLinalgConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns);
-
-/// Populates the patterns that convert from reduction StableHLO ops to Linalg
-/// on tensors.
-void populateStableHloReductionToLinalgConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns, bool enablePrimitiveOps);
-
-/// Populates the patterns that convert scalar StableHLO ops to Arith ops.
-void populateScalarHloToArithConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns,
-    llvm::function_ref<bool(Operation *)> filterFn = nullptr);
-} // namespace detail
-
 } // namespace mlir::iree_compiler::stablehlo
 
 #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_REWRITERS_H_
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp
deleted file mode 100644
index 7a62ec2aca80..000000000000
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToArith.cpp
+++ /dev/null
@@ -1,146 +0,0 @@
-// Copyright 2022 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-// Implements logic for lowering scalar StableHLO ops to arith dialect.
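// Roughly, the deleted file below rewrote ops on rank-0 tensors into scalar
// arith ops bracketed by tensor.extract / tensor.from_elements. As a sketch:
//
//   %r = stablehlo.add %a, %b : tensor<f32>
//     ==>
//   %ae = tensor.extract %a[] : tensor<f32>
//   %be = tensor.extract %b[] : tensor<f32>
//   %s  = arith.addf %ae, %be : f32
//   %r  = tensor.from_elements %s : tensor<f32>
//
// Equivalent patterns are presumably provided by the upstream
// @stablehlo//:linalg_passes target added to the build files above.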
- -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { - -template -struct ScalarHloToFuncPatterns final : OpConversionPattern { - ScalarHloToFuncPatterns(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit) {} - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op->getParentOp())) { - return rewriter.notifyMatchFailure(op, - "Return must be inside a function"); - } - mlir::Operation::operand_range operands = op.getOperands(); - rewriter.replaceOpWithNewOp(op, operands); - return success(); - } -}; -template -struct ScalarHloToArithmeticPattern final : OpConversionPattern { - ScalarHloToArithmeticPattern( - TypeConverter &typeConverter, MLIRContext *context, - llvm::function_ref filterFn = nullptr, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - filterFn(filterFn) {} - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (filterFn && !filterFn(op)) - return failure(); - - auto isScalar = [](Value v) { - return cast(v.getType()).getRank() == 0; - }; - - if (!llvm::all_of(adaptor.getOperands(), isScalar)) - return rewriter.notifyMatchFailure(op, "All operands must be scalar."); - - Location loc = op.getLoc(); - - auto resultTy = dyn_cast_or_null( - this->getTypeConverter()->convertType(op->getResultTypes().front())); - if (!resultTy) - return failure(); - - SmallVector operands; - for (Value operand : adaptor.getOperands()) { - operands.push_back( - rewriter.create(loc, operand, ValueRange())); - } - Value scalarResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, resultTy.getElementType(), operands, &rewriter); - if (!scalarResult) - return failure(); - rewriter.replaceOpWithNewOp(op, resultTy, - scalarResult); - return success(); - } - -private: - llvm::function_ref filterFn; -}; - -} // namespace - -namespace detail { -void populateScalarHloToArithConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, - llvm::function_ref filterFn) { - // TODO(#12678): Handle the XLA rng op. 
- patterns->add< - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern, - ScalarHloToArithmeticPattern>(typeConverter, - context, filterFn); - patterns->add>( - typeConverter, context); -} -} // namespace detail - -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp index 65963904b0b7..bce2f06187e6 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp @@ -12,7 +12,6 @@ #include "compiler/plugins/input/StableHLO/Conversion/Passes.h" #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h" #include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" @@ -35,6 +34,8 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/conversions/linalg/transforms/Rewriters.h" +#include "stablehlo/conversions/linalg/transforms/TypeConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -46,7 +47,9 @@ namespace mlir::iree_compiler::stablehlo { namespace { /// Converts stablehlo.concatenate operation to extract_slice ops + insert_slice -/// ops. +/// ops. mlir::stablehlo::populateStablehloToLinalgConversionPatterns provides a +/// lowering to linalg using SCF that has numerics issues when run through IREE, +/// so we use this lowering instead with a higher pattern benefit. 
struct ConcatenateOpConversion final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -491,10 +494,11 @@ struct ConvertStableHloToIreeInputDialects final : impl::ConvertStableHloToIreeInputDialectsBase< ConvertStableHloToIreeInputDialects> { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert< - IREE::Flow::FlowDialect, IREE::Util::UtilDialect, linalg::LinalgDialect, - arith::ArithDialect, tensor::TensorDialect, shape::ShapeDialect, - math::MathDialect, memref::MemRefDialect, complex::ComplexDialect>(); + registry.insert(); } void runOnOperation() override { @@ -502,7 +506,7 @@ struct ConvertStableHloToIreeInputDialects final RewritePatternSet patterns(context); std::unique_ptr typeConverter = - createStableHloToLinalgTypeConverter(); + std::make_unique<::mlir::stablehlo::LinalgTypeConverter>(); typeConverter->addArgumentMaterialization(scalarToTensor); typeConverter->addSourceMaterialization(scalarToTensor); typeConverter->addTargetMaterialization(scalarToTensor); @@ -511,16 +515,23 @@ struct ConvertStableHloToIreeInputDialects final // expensive expansions. populateCanonicalizationPatterns(context, &patterns, /*benefit=*/1024); - populateStableHloToLinalgOnTensorsConversionPatterns( - context, *typeConverter, &patterns); + // Run custom patterns with a high benefit to override stablehlo patterns. + patterns.add(*typeConverter, context, + PatternBenefit{1000}); + + // Run upstream stablehlo patterns with a default benefit. + ::mlir::stablehlo::populateStablehloToLinalgConversionPatterns( + context, *typeConverter, &patterns, /*enablePrimitiveOps=*/false, + /*enableSparseOps=*/false); + + // Lowerings using IREE-specific operators (and not just common dialects + // like linalg, scf, arith, etc.). populateStableHloCollectivesConversionPatterns(context, *typeConverter, &patterns); - // TODO(#12678): Handle remaining complex ops. - // TODO(*): expose patterns that do this much better from - // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp - + // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp // Structural patterns (functions, cfg, terminators). patterns.add(*typeConverter, context); patterns.add(*typeConverter, context); @@ -626,17 +637,4 @@ struct ConvertStableHloToIreeInputDialects final } // namespace -void populateStableHloToLinalgOnTensorsConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - // TODO(#5809): Drop ConcatenateOp lowering in favor of the upstream version - // then remove the PatternBenefit here - patterns->add(typeConverter, context, - PatternBenefit{1000}); - - populateStableHloToLinalgConversionPatterns(context, typeConverter, patterns, - /*enablePrimitiveOps=*/false); -} - } // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp deleted file mode 100644 index c886df805ec6..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp +++ /dev/null @@ -1,2683 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO/CHLO dialects to Linalg dialect. 
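// The ~2700 deleted lines below implemented StableHLO->Linalg lowerings that
// now come from the upstream stablehlo repository. Their recurring
// scaffolding, reduced to a sketch for a unary elementwise op (identifier
// names here are illustrative, not from the original file):
static Value elementwiseToLinalgSketch(ConversionPatternRewriter &rewriter,
                                       Location loc, Value operand,
                                       RankedTensorType resultTy) {
  int64_t rank = resultTy.getRank();
  Value init = rewriter.create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                resultTy.getElementType());
  // Identity indexing maps on input and output; all loops parallel.
  SmallVector<AffineMap> maps(2, rewriter.getMultiDimIdentityMap(rank));
  auto generic = rewriter.create<linalg::GenericOp>(
      loc, /*resultTensorTypes=*/resultTy, /*inputs=*/operand,
      /*outputs=*/init, maps, getNParallelLoopsAttrs(rank),
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // A real pattern maps the HLO op to its arith/math equivalent here;
        // this sketch just forwards the input element.
        b.create<linalg::YieldOp>(nestedLoc, args.front());
      });
  return generic.getOperation()->getResult(0);
}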
- -#include -#include -#include - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Passes.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { - -#define GEN_PASS_DEF_CONVERTSTABLEHLOTOLINALG -#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc" - -namespace { -Value getResultValue(Operation *op) { return op->getResult(0); } - -ShapedType getHloOpResultType(Operation *op) { - return llvm::cast(getResultValue(op).getType()); -} - -/// Extracts an element from a tensor and optionally converts it to an index -/// type, based on the tensor's pre-type conversion type. -Value extractIndexFromTensor(OpBuilder &builder, Location loc, Value tensor, - ShapedType originalType, - ArrayRef tensorIndex = {}) { - Value extracted = builder.create(loc, tensor, tensorIndex); - if (extracted.getType().isIndex()) - return extracted; - return originalType.getElementType().isUnsignedInteger() - ? builder.createOrFold( - loc, builder.getIndexType(), extracted) - : builder.createOrFold( - loc, builder.getIndexType(), extracted); -} - -//===----------------------------------------------------------------------===// -// stablehlo.Einsum conversion patterns. -//===----------------------------------------------------------------------===// - -// Looks through a set of dimension that has been marked as reduction axes, -// if it is found within the set, then we set it as "reduction", otherwise -// we can label it as "parallel". 
-SmallVector -getEinsumLoopsAttrs(const llvm::SmallSetVector &inputInd, - const llvm::SmallSetVector &reductionDims) { - SmallVector res; - for (StringRef dim : inputInd) { - if (!reductionDims.contains(dim)) { - res.push_back(utils::IteratorType::parallel); - } else { - res.push_back(utils::IteratorType::reduction); - } - } - return res; -} - -SmallVector -extractDynamicEinsumSizes(OpBuilder &b, Location loc, Value lhs, Value rhs, - const SmallVector &lhsLoopVec, - const SmallVector &rhsLoopVec, - const SmallVector &outputLoopVec) { - SmallVector dynSizes; - for (const std::string &dimInd : outputLoopVec) { - Value dimSize; - const auto *dimIndIt = llvm::find(lhsLoopVec, dimInd); - if (dimIndIt != lhsLoopVec.end()) { - // Query from lhs vars. - auto dimIndPos = dimIndIt - lhsLoopVec.begin(); - auto lhsShape = - llvm::dyn_cast(lhs.getType()).getShape(); - if (!ShapedType::isDynamic(lhsShape[dimIndPos])) - continue; - dimSize = b.create(loc, lhs, dimIndPos); - } else { - // query from rhs vars. - dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd); - auto dimIndPos = dimIndIt - rhsLoopVec.begin(); - auto rhsShape = - llvm::dyn_cast(rhs.getType()).getShape(); - if (!ShapedType::isDynamic(rhsShape[dimIndPos])) - continue; - dimSize = b.create(loc, rhs, dimIndPos); - } - dynSizes.push_back(dimSize); - } - return dynSizes; -} - -// Adds indices/axes that are missing from output set. -llvm::SmallSetVector -findSummationAxes(const llvm::SmallSetVector &inputSet, - const llvm::SmallSetVector &outputSet) { - llvm::SmallSetVector summationAxes; - for (StringRef ind : inputSet) { - if (!outputSet.contains(ind)) - summationAxes.insert(ind); - } - return summationAxes; -} - -// Given a 1:1 map from std::string -> affine dimension expression -// we can get the affine expression of dimensions that an -// operand will access based on the input_str of einsum_config. -// For example: -// let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2} -// for einsum_config "abc,cb->acb" -// first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2). -// second_input_operand will get umap[{"c","b"}] -> (d2, d1). -// output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1). -SmallVector -getExprFromConfig(const SmallVector &loopDims, - const DenseMap &strAffineDimUmap) { - SmallVector exprs; - for (const auto &dim : loopDims) { - exprs.push_back(strAffineDimUmap.lookup(dim)); - } - return exprs; -} - -// Convert stablehlo.einsum op into linalg.generic. -// Algorithm in general 3 steps: - -// Step1) Dissect entire einsum_config to different operands -// e.g f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}. - -// Step2) Split up the string into vector of the elements -// e.g {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"], -// rhs:["c","d"], out:["a","b","d"]}. - -// Step3) Convert the vector into data access -// patern represented by affineMaps with affineDimensions e.g -// {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2], -// rhs:[d2,d3], out:[d0,d1,d3]}. -struct EinsumToLinalgConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::EinsumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto getRank = [](Value v) { - return llvm::cast(v.getType()).getRank(); - }; - auto einsumConfig = op.getEinsumConfig(); - - // With the assumption of binary input operand and single output - // get the inputs and output operands' indices. 
- // einsum_config = "lhs_loop,rhs_loop->out_loop" - std::size_t posArrow = einsumConfig.find(kArrow); - std::size_t posComma = einsumConfig.find(kComma); - - StringRef lhsLoop = einsumConfig.substr(0, posComma); - StringRef rhsLoop = einsumConfig.substr( - posComma + kComma.size(), posArrow - (posComma + kComma.size())); - StringRef outLoop = einsumConfig.substr(posArrow + kArrow.size()); - - // Check for Invalid Configs. - // 1.Check that there is only maximum 2 inputs - // 2.Check that there is only maximum 1 output - // 3.Check that there is 1 kArrow - if (rhsLoop.contains(kComma) || outLoop.contains(kComma) || - outLoop.contains(kArrow)) { - return rewriter.notifyMatchFailure(op, "Invalid einsum config!"); - } - - // Find result type, if on tensors. - auto resultTy = getTypeConverter()->convertType( - getHloOpResultType(op)); - - // Check result type compatibility. - if (!resultTy || !resultTy.getElementType().isSignlessIntOrFloat()) { - return rewriter.notifyMatchFailure(op, "Invalid result type"); - } - - // Convert the representation to vector. - SmallVector lhsEin = - getEinsumConfigAsVector(lhsLoop, getRank(adaptor.getLhs())); - SmallVector rhsEin = - getEinsumConfigAsVector(rhsLoop, getRank(adaptor.getRhs())); - SmallVector outEin = - getEinsumConfigAsVector(outLoop, resultTy.getRank()); - - if (!checkBatchHasEqualRank(lhsEin.size(), lhsLoop, rhsEin.size(), rhsLoop, - outEin.size(), outLoop)) { - return rewriter.notifyMatchFailure( - op, "Invalid elipsis('...') within einsum config!"); - } - - // Find all unique indices in the input and output. - llvm::SmallSetVector inputInd; - llvm::SmallSetVector outputInd; - - inputInd.insert(lhsEin.begin(), lhsEin.end()); - inputInd.insert(rhsEin.begin(), rhsEin.end()); - outputInd.insert(outEin.begin(), outEin.end()); - - llvm::SmallSetVector reductionAxe = - findSummationAxes(inputInd, outputInd); - - // Find input/output values and types. - Location loc = op.getLoc(); - - // Prepare init tensor for linalg.generic op. - auto dynSizes = - extractDynamicEinsumSizes(rewriter, loc, adaptor.getLhs(), - adaptor.getRhs(), lhsEin, rhsEin, outEin); - Value output = getEmptyTensor(rewriter, loc, resultTy, dynSizes); - if (!reductionAxe.empty()) { - output = fillTensorWithZeros(rewriter, loc, output); - } - - // Create indexing maps. - // Create a 1:1 map from f:strDimension -> affineDimension. - int64_t nloops = inputInd.size(); - DenseMap strAffineDimUmap; - for (auto [idx, value] : llvm::enumerate(inputInd)) { - strAffineDimUmap[value] = rewriter.getAffineDimExpr(idx); - } - - // From einsum_config of each operand in vector, generate - // the equivalent vector. - SmallVector maps; - for (const SmallVector &loopOperand : - {lhsEin, rhsEin, outEin}) { - auto exprs = getExprFromConfig(loopOperand, strAffineDimUmap); - maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext())); - } - - auto linalgOp = rewriter.create( - loc, resultTy ? 
resultTy : TypeRange{}, adaptor.getOperands(), output, - maps, getEinsumLoopsAttrs(inputInd, reductionAxe), - [reductionAxe](OpBuilder &b, Location nestedLoc, ValueRange args) { - Value resultVal = - b.create(nestedLoc, args[0], args[1]); - if (!reductionAxe.empty()) { - resultVal = - b.create(nestedLoc, args[2], resultVal); - } - b.create(nestedLoc, resultVal); - }, - linalg::getPrunedAttributeList(op)); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } - -private: - static constexpr StringLiteral kArrow = "->"; - static constexpr StringLiteral kComma = ","; - static constexpr StringLiteral kEllipsis = "..."; - - static bool checkBatchHasEqualRank(size_t lhsRank, StringRef lhsLoop, - size_t rhsRank, StringRef rhsLoop, - size_t outRank, StringRef outLoop); - static SmallVector getEinsumConfigAsVector(StringRef loop, - size_t operandRank); -}; - -// Convert the representation from string/vector to vector. -// i.e ("abc") -> {"a", "b", "c"}. For cases with ellipsis with batch rank 3: -// get loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"} -SmallVector -EinsumToLinalgConverter::getEinsumConfigAsVector(StringRef loop, - size_t operandRank) { - SmallVector loopDim; - size_t preElip = loop.find(kEllipsis); - bool hasElip = preElip != StringRef::npos; - if (!hasElip) - preElip = loop.size(); - // Add the dimension until the end or up to ellipsis if it exist. - for (int64_t preElipInd = 0; preElipInd < static_cast(preElip); - preElipInd++) { - loopDim.push_back(loop.substr(preElipInd, 1).str()); - } - if (!hasElip) - return loopDim; - // Case where Ellipsis presence: - size_t nonBatchRank = loop.size() - kEllipsis.size(); - size_t batchRank = operandRank - nonBatchRank; - // Add the batch dimension ("0",...,"N") where N is rank of batch into the - // loop. - for (int64_t batchInd = 0; batchInd < static_cast(batchRank); - batchInd++) { - loopDim.push_back(std::to_string(batchInd)); - } - // Add the dimension after ellipsis into the loop. - int postElip = preElip + kEllipsis.size(); - for (int64_t postElipInd = postElip; - postElipInd < static_cast(loop.size()); ++postElipInd) { - loopDim.push_back(loop.substr(postElipInd, 1).str()); - } - return loopDim; -} - -// Returns true if all operand's batch has same rank. -bool EinsumToLinalgConverter::checkBatchHasEqualRank( - size_t lhsRank, StringRef lhsLoop, size_t rhsRank, StringRef rhsLoop, - size_t outRank, StringRef outLoop) { - SmallVector batchRankVec; - if (lhsRank != lhsLoop.size()) { - size_t lhsBatchRank = lhsRank - (lhsLoop.size() - kEllipsis.size()); - batchRankVec.push_back(lhsBatchRank); - } - if (rhsRank != rhsLoop.size()) { - size_t rhsBatchRank = rhsRank - (rhsLoop.size() - kEllipsis.size()); - batchRankVec.push_back(rhsBatchRank); - } - if (outRank != outLoop.size()) { - size_t outBatchRank = outRank - (outLoop.size() - kEllipsis.size()); - batchRankVec.push_back(outBatchRank); - } - bool batchHasEqualRank = true; - - // Condition is valid if only 1 operand or less have batches. - if (batchRankVec.size() < 2) - return batchHasEqualRank; - - if (!llvm::all_equal(batchRankVec)) - return false; - - return batchHasEqualRank; -} - -/// Base class for lowering HLO operations that have one operand and one result, -/// and are semantically equivalent to a copy of the input to the output (like -/// transpose, some reshape, etc.). The derived classes need to provide a method -/// `getIndexingMaps` that returns AffineMaps for the index maps of the input -/// and the output. 
-template -struct DataMovementOpConverter : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - ShapedType resultType = getHloOpResultType(op); - resultType = - this->getTypeConverter()->template convertType(resultType); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - - SmallVector indexingMaps = - Derived::getIndexingMaps(op, &rewriter); - if (indexingMaps.empty()) - return failure(); - - int64_t nloops = resultType.getRank(); - Location loc = op.getLoc(); - auto linalgOp = rewriter.create( - loc, - /*resultTensorTypes=*/resultType, - /*inputs=*/adaptor.getOperands().front(), - /*outputBuffers=*/ - - ValueRange{getEmptyTensorFor(rewriter, loc, resultType, op, - adaptor.getOperands())}, - indexingMaps, getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - nestedBuilder.create(loc, *args.begin()); - }, - linalg::getPrunedAttributeList(op)); - rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); - return success(); - } -}; - -/// Pattern to convert BroadcastOp to Linalg ops. -template -struct BroadcastConverter final - : DataMovementOpConverter, OpTy> { - using DataMovementOpConverter::DataMovementOpConverter; - - static SmallVector getIndexingMaps(OpTy broadcastOp, - Builder *b) { - ShapedType inputType = - llvm::cast(broadcastOp.getOperand().getType()); - unsigned inputRank = inputType.getRank(); - unsigned nloops = getHloOpResultType(broadcastOp).getRank(); - - // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to - // the input's dimensions. - unsigned numPrependedDims = llvm::size(broadcastOp.getBroadcastSizes()); - SmallVector inputDimExprs; - inputDimExprs.reserve(inputRank); - for (unsigned i = 0; i < inputRank; ++i) { - inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); - } - - AffineMap inputMap; - MLIRContext *context = b->getContext(); - if (inputDimExprs.empty()) { - // The input is a scalar, i.e. this is a scalar broadcast op. 
- inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); - } else { - inputMap = - AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); - } - return {inputMap, b->getMultiDimIdentityMap(nloops)}; - } -}; - -struct BroadcastOpToBroadcastConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::BroadcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - int64_t numPrependedDims = op.getBroadcastSizes().size(); - SmallVector dimensions = - llvm::to_vector(llvm::seq(0, numPrependedDims)); - - Location loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - rewriter.replaceOpWithNewOp( - op, op.getOperand(), emptyTensor, dimensions, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct HloBroadcastInDimConverter final - : DataMovementOpConverter { - using DataMovementOpConverter::DataMovementOpConverter; - - static SmallVector - getIndexingMaps(mlir::stablehlo::BroadcastInDimOp broadcastOp, Builder *b) { - ShapedType resultType = getHloOpResultType(broadcastOp); - auto operandType = cast(broadcastOp.getOperand().getType()); - unsigned nloops = resultType.getRank(); - - // The input is a scalar, i.e. this is a scalar broadcast op. - if (operandType.getRank() == 0) { - return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } - - ArrayRef operandShape = operandType.getShape(); - SmallVector dimExprs; - dimExprs.reserve(nloops); - - for (auto [idx, size] : - llvm::enumerate(broadcastOp.getBroadcastDimensions())) { - bool expansionNeeded = - operandShape[idx] == 1 && resultType.getShape()[size] != 1; - dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0) - : b->getAffineDimExpr(size)); - } - return { - AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } -}; - -Value collapseExpandingDims(PatternRewriter &rewriter, Location loc, - Value operand, SmallVector &dimensions, - llvm::function_ref isExpandingDim) { - auto operandTy = llvm::cast(operand.getType()); - - SmallVector reassociationMap; - ReassociationIndices currentIndices; - - ArrayRef operandShape = operandTy.getShape(); - SmallVector newOperandShape; - SmallVector newDimensions; - - for (auto [idx, dim] : llvm::enumerate(dimensions)) { - currentIndices.push_back(idx); - - if (!isExpandingDim(idx)) { - reassociationMap.push_back(currentIndices); - currentIndices.clear(); - newOperandShape.push_back(operandShape[idx]); - newDimensions.push_back(dim); - } - } - - if (!reassociationMap.empty()) { - reassociationMap.back().insert(reassociationMap.back().end(), - currentIndices.begin(), - currentIndices.end()); - } - - if (dimensions.size() != newDimensions.size()) { - dimensions = newDimensions; - - auto newOperandType = - RankedTensorType::get(newOperandShape, operandTy.getElementType()); - operand = rewriter.create( - loc, newOperandType, operand, reassociationMap); - } - return operand; -} - -// Insert linalg.transpose if broadcasted dimensions are not in sorted order. -// linalg.broadcast does not support implicit transpose, so the input needs to -// be explicitly transposed. 
-Value transposeBroadcastOperand(PatternRewriter &rewriter, Location loc, - Value operand, - SmallVector &dimensions) { - // Do not insert `transpose` is dimensions are already sorted. - if (llvm::is_sorted(dimensions)) - return operand; - - SmallVector permutation = - llvm::to_vector(llvm::seq(0, dimensions.size())); - llvm::sort(permutation, [&](int64_t lhs, int64_t rhs) { - return dimensions[lhs] < dimensions[rhs]; - }); - - auto operandTy = llvm::cast(operand.getType()); - ArrayRef operandShape = operandTy.getShape(); - SmallVector transposedOperandShape, transposedDimensions; - - for (int64_t index : permutation) { - transposedOperandShape.push_back(operandShape[index]); - transposedDimensions.push_back(dimensions[index]); - } - dimensions = transposedDimensions; - - return rewriter.create( - loc, - RankedTensorType::get(transposedOperandShape, operandTy.getElementType()), - operand, rewriter.getDenseI64ArrayAttr(permutation)); -} - -struct BroadcastInDimOpToBroadcastConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions()); - - Value operand = adaptor.getOperand(); - auto operandTy = llvm::cast(operand.getType()); - auto resultTy = - llvm::cast(typeConverter->convertType(op.getType())); - - ArrayRef operandShape = operandTy.getShape(); - ArrayRef resultShape = resultTy.getShape(); - - operand = collapseExpandingDims( - rewriter, loc, operand, broadcastDimensions, [&](int64_t i) { - return operandShape[i] == 1 && - resultShape[broadcastDimensions[i]] != 1; - }); - operand = - transposeBroadcastOperand(rewriter, loc, operand, broadcastDimensions); - - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - SmallVector addedDimensions; - for (int64_t dim : llvm::seq(0, resultTy.getRank())) { - if (!llvm::is_contained(broadcastDimensions, dim)) - addedDimensions.push_back(dim); - } - - rewriter.replaceOpWithNewOp( - op, operand, emptyTensor, addedDimensions, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -// If the input has a static shape we know exactly when the broadcast must -// expand (the dimension is 1, which also trivially expands to 1) or will never -// expand (the dimension is not 1). We can also source the information from the -// optionally provided attributes on statically known broadcasting behavior. -// This means we can lower the broadcast just as we would lower a fully static -// broadcast and go directly to `linalg.generic`. - -// This also covers the important case of broadcasting a scalar. Ideally the -// pattern (`stablehlo.constant` -> `stablehlo.dynamic_broadcast_in_dim`) should -// be converted to a tensor dialect op similar to TF's `ConstantLikeOp`. 
-struct HloDynamicBroadcastInDimConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value operand = adaptor.getOperand(); - auto operandType = dyn_cast(operand.getType()); - if (!operandType) - return failure(); - auto resultType = - getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - // Determine dimension expressions based on whether the dimension is - // expanding (0) or non-expanding (identity), and fail if we cannot decide - // this. - SmallVector dimExprs(operandType.getRank(), nullptr); - - // Use static type info. - auto bcastDims = - llvm::map_to_vector(op.getBroadcastDimensions(), - [](int64_t d) { return static_cast(d); }); - for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) { - if (ShapedType::isDynamic(dim)) - continue; - - bool isExpanding = dim == 1; - dimExprs[idx] = isExpanding ? rewriter.getAffineConstantExpr(0) - : rewriter.getAffineDimExpr(bcastDims[idx]); - } - - // Use annotated expansion behavior, if available. - if (auto dims = op.getKnownExpandingDimensions()) { - for (int i : *dims) { - dimExprs[i] = rewriter.getAffineConstantExpr(0); - } - } - if (auto dims = op.getKnownNonexpandingDimensions()) { - for (int i : *dims) { - dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]); - } - } - - // Fail if unknown expansion behavior remains. - if (!llvm::all_of(dimExprs, [](AffineExpr expr) { return expr; })) - return failure(); - - // Materialize `linalg.generic` op. - Location loc = op.getLoc(); - int64_t nloops = resultType.getRank(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - rewriter.replaceOpWithNewOp( - op, TypeRange{emptyTensor.getType()}, ValueRange{operand}, - /*outputBuffers=*/ValueRange{emptyTensor}, - llvm::ArrayRef({AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, - dimExprs, rewriter.getContext()), - rewriter.getMultiDimIdentityMap(nloops)}), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - nestedBuilder.create(loc, *args.begin()); - }, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct DynamicBroadcastInDimOpToBroadcastConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - Value operand = adaptor.getOperand(); - auto operandTy = llvm::dyn_cast(operand.getType()); - if (!operandTy) - return failure(); - auto resultTy = - getTypeConverter()->convertType(op.getType()); - if (!resultTy) - return failure(); - - SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions()); - - SmallVector> expansionBehavior( - broadcastDimensions.size()); - - // Use static type info. - for (auto [idx, dim] : llvm::enumerate(operandTy.getShape())) { - if (ShapedType::isDynamic(dim)) - continue; - expansionBehavior[idx] = (dim == 1); - } - - // Use annotated expansion behavior, if available. 
- if (op.getKnownExpandingDimensions()) { - auto dims = op.getKnownExpandingDimensions().value(); - for (int it : dims) { - expansionBehavior[it] = true; - } - } - if (op.getKnownNonexpandingDimensions()) { - auto dims = op.getKnownNonexpandingDimensions().value(); - for (int it : dims) { - expansionBehavior[it] = false; - } - } - - // Fail if unknown expansion behavior remains. - if (llvm::any_of(expansionBehavior, [](auto v) { return !v.has_value(); })) - return failure(); - - auto isExpandingDim = [&](int64_t i) { - return expansionBehavior[i].value(); - }; - - // Use attribute information to insert 1s into operand type. - operand = getBroadcastOperand(rewriter, loc, operand, isExpandingDim); - - auto broadcastResultTy = getBroadcastResultType( - operand, resultTy, broadcastDimensions, isExpandingDim); - - operand = collapseExpandingDims(rewriter, loc, operand, broadcastDimensions, - isExpandingDim); - operand = - transposeBroadcastOperand(rewriter, loc, operand, broadcastDimensions); - - Value emptyTensor = getEmptyTensorFor(rewriter, loc, broadcastResultTy, op, - adaptor.getOperands()); - - SmallVector addedDimensions; - for (int64_t dim : llvm::seq(0, resultTy.getRank())) { - if (!llvm::is_contained(broadcastDimensions, dim)) - addedDimensions.push_back(dim); - } - - Value result = rewriter - .create( - loc, operand, emptyTensor, addedDimensions, - linalg::getPrunedAttributeList(op)) - .getResults()[0]; - - if (resultTy != broadcastResultTy) { - result = rewriter.create(loc, resultTy, result); - } - - rewriter.replaceOp(op, result); - return success(); - } - -private: - static Value - getBroadcastOperand(PatternRewriter &rewriter, Location loc, Value operand, - llvm::function_ref isExpandingDim) { - auto operandTy = llvm::dyn_cast(operand.getType()); - - SmallVector updatedOperandShape = - llvm::to_vector(operandTy.getShape()); - for (auto [idx, dim] : llvm::enumerate(updatedOperandShape)) { - if (isExpandingDim(idx)) - dim = 1; - } - - auto updatedOperandTy = - RankedTensorType::get(updatedOperandShape, operandTy.getElementType()); - - if (updatedOperandTy != operandTy) { - operand = rewriter.create(loc, updatedOperandTy, operand); - } - - return operand; - } - - static ShapedType - getBroadcastResultType(Value operand, RankedTensorType resultTy, - ArrayRef dimensions, - llvm::function_ref isExpandingDim) { - auto operandShape = - llvm::cast(operand.getType()).getShape(); - auto broadcastResultShape = llvm::to_vector(resultTy.getShape()); - - for (auto [operandIndex, resultIndex] : llvm::enumerate(dimensions)) { - if (isExpandingDim(operandIndex)) - continue; - broadcastResultShape[resultIndex] = operandShape[operandIndex]; - } - - return RankedTensorType::get(broadcastResultShape, - resultTy.getElementType()); - } -}; - -template -struct TransposeConverter final - : DataMovementOpConverter, OpTy> { - using DataMovementOpConverter, - OpTy>::DataMovementOpConverter; - - static SmallVector getIndexingMaps(OpTy op, Builder *b) { - auto resultType = llvm::cast(getHloOpResultType(op)); - int64_t nloops = resultType.getRank(); - SmallVector inputExprs; - inputExprs.resize(resultType.getRank()); - for (auto [idx, value] : llvm::enumerate(op.getPermutation())) { - inputExprs[value] = b->getAffineDimExpr(idx); - } - return { - AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } -}; - -struct TransposeOpToTransposeConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - 
matchAndRewrite(mlir::stablehlo::TransposeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - Location loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - auto permutation = - dyn_cast_or_null(op.getPermutationAttr()); - - rewriter.replaceOpWithNewOp( - op, adaptor.getOperand(), emptyTensor, permutation, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct BitcastConvertConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::BitcastConvertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - auto inputType = - llvm::cast(adaptor.getOperand().getType()); - auto outputType = - getTypeConverter()->convertType(op.getType()); - if (!outputType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - Location loc = op.getLoc(); - - // Fallback to pointwise conversion if the tensor dimensions are not - // changing. - if (inputType.getRank() == outputType.getRank()) { - return failure(); - } - - auto inputBitWidth = inputType.getElementType().getIntOrFloatBitWidth(); - auto outputBitWidth = outputType.getElementType().getIntOrFloatBitWidth(); - - auto maxRank = std::max(inputType.getRank(), outputType.getRank()); - auto identityMap = - AffineMap::getMultiDimIdentityMap(maxRank, rewriter.getContext()); - AffineMap indexingMaps[] = { - AffineMap::get( - /*dimCount=*/maxRank, /*symbolCount=*/0, - identityMap.getResults().take_front(inputType.getRank()), - rewriter.getContext()), - AffineMap::get( - /*dimCount=*/maxRank, /*symbolCount=*/0, - identityMap.getResults().take_front(outputType.getRank()), - rewriter.getContext())}; - - Value output = - getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); - bool isExpansion = inputBitWidth > outputBitWidth; - bool isContraction = inputBitWidth < outputBitWidth; - // When combining values we start with a 0 and merge bits into it. - if (isContraction) { - output = fillTensorWithZeros(rewriter, loc, output); - } - - rewriter.replaceOpWithNewOp( - op, outputType, adaptor.getOperand(), output, indexingMaps, - getParallelAndReductionIterators(maxRank, isContraction ? 1 : 0), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - auto inIntType = nestedBuilder.getIntegerType(inputBitWidth); - auto outIntType = nestedBuilder.getIntegerType(outputBitWidth); - Value innerResult = args.front(); - if (isExpansion) { - // Expand a big value into multiple small values with shifts. - auto iotaIndex = - nestedBuilder.create(nestedLoc, maxRank - 1); - auto iota = nestedBuilder.create( - nestedLoc, inIntType, iotaIndex); - - auto width = nestedBuilder.create( - nestedLoc, - nestedBuilder.getIntegerAttr(inIntType, outputBitWidth)); - auto shiftWidth = - nestedBuilder.create(nestedLoc, iota, width); - Value inputCasted = nestedBuilder.create( - nestedLoc, inIntType, args.front()); - Value shifted = nestedBuilder.create( - nestedLoc, inputCasted, shiftWidth); - innerResult = nestedBuilder.create( - nestedLoc, outIntType, shifted); - } else if (isContraction) { - // Combine multiple small values into one big value. 
- auto iotaIndex = - nestedBuilder.create(nestedLoc, maxRank - 1); - auto iota = nestedBuilder.create( - nestedLoc, outIntType, iotaIndex); - - auto width = nestedBuilder.create( - nestedLoc, - nestedBuilder.getIntegerAttr(outIntType, inputBitWidth)); - auto shiftWidth = - nestedBuilder.create(nestedLoc, iota, width); - Value inputCasted = nestedBuilder.create( - nestedLoc, inIntType, args.front()); - Value inputExt = nestedBuilder.create( - nestedLoc, outIntType, inputCasted); - Value shifted = nestedBuilder.create( - nestedLoc, inputExt, shiftWidth); - Value accumulatorCasted = nestedBuilder.create( - nestedLoc, outIntType, args.back()); - innerResult = nestedBuilder.create( - nestedLoc, outIntType, shifted, accumulatorCasted); - } - innerResult = nestedBuilder.create( - nestedLoc, outputType.getElementType(), innerResult); - nestedBuilder.create(nestedLoc, innerResult); - }, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -// Lowers stablehlo.RealDynamicSliceOp to tensor.extract_slice and other -// arith/tensor dialect ops. -struct RealDynamicSliceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - // Computes size of a slice as - // size = ceil((limit - start)/stride) - static Value computeSize(Location loc, Value start, Value limit, Value stride, - ConversionPatternRewriter &b) { - Value delta = b.create(loc, limit, start); - Value ret = b.create(loc, delta, stride); - if (ret.getType().isIndex()) - return ret; - return b.create(loc, b.getIndexType(), ret); - } - - LogicalResult - matchAndRewrite(mlir::stablehlo::RealDynamicSliceOp realDynamicSliceOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = realDynamicSliceOp.getLoc(); - auto argType = llvm::dyn_cast(adaptor.getOperand().getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(realDynamicSliceOp, - "require known-rank args"); - } - - Type dimElementType = getElementTypeOrSelf(adaptor.getStartIndices()); - if (getElementTypeOrSelf(adaptor.getLimitIndices()) != dimElementType || - getElementTypeOrSelf(adaptor.getStrides()) != dimElementType) { - return rewriter.notifyMatchFailure( - realDynamicSliceOp, - "requires same element type for all dimension specification"); - } - Type arithType = - dimElementType.isIndex() ? rewriter.getI64Type() : dimElementType; - Type indexType = rewriter.getIndexType(); - - auto resultType = llvm::cast( - this->typeConverter->convertType(realDynamicSliceOp.getType())); - Value zero = rewriter.create(loc, 0); - SmallVector offsets, sizes, strides; - SmallVector clampType(3, arithType); - for (auto i : llvm::seq(0, argType.getRank())) { - Value dim = rewriter.create(loc, i); - Value start = rewriter.create( - loc, adaptor.getStartIndices(), dim); - Value limit = rewriter.create( - loc, adaptor.getLimitIndices(), dim); - Value stride = - rewriter.create(loc, adaptor.getStrides(), dim); - - // Compute i-th dimension size of the result : size[i]. - // If the i-th dimension of the result type is known, we go ahead with it - // else we compute it using limit, start and stride values. - int64_t resultDimSize = resultType.getDimSize(i); - Value size = - ShapedType::isDynamic(resultDimSize) - ? computeSize(loc, start, limit, stride, rewriter) - : rewriter.create(loc, resultDimSize); - - // We can now convert start to index. 
- if (!start.getType().isIndex()) - start = rewriter.create( - loc, rewriter.getIndexType(), start); - - // Fetch i-th dimension size of the operand and calculate upper bound as - // ub = operand_dim[i] - size[i] - Value operandDimSize = - rewriter.createOrFold(loc, adaptor.getOperand(), dim); - Value upperBound = - rewriter.createOrFold(loc, operandDimSize, size); - - // We clamp the start_index to keep it bounded as - // 0 <= start_index[i] <= ub - // Clamp does not support index type, so cast to integer type. - start = rewriter.create(loc, start, zero); - start = rewriter.create(loc, start, upperBound); - - offsets.push_back(start); - if (ShapedType::isDynamic(resultDimSize)) - sizes.push_back(size); - else - sizes.push_back(IntegerAttr::get(indexType, resultDimSize)); - - if (!stride.getType().isIndex()) - stride = - rewriter.createOrFold(loc, indexType, stride); - strides.push_back(stride); - } - - rewriter.replaceOpWithNewOp( - realDynamicSliceOp, resultType, adaptor.getOperand(), offsets, sizes, - strides); - return success(); - } -}; - -// Converts reshape ops that can be proven to be either a collapse of dimensions -// or expansion of dimensions of the operand. -struct ReshapeOpConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReshapeOp reshapeOp, - mlir::stablehlo::ReshapeOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(reshapeOp))) - return failure(); - Value operand = adaptor.getOperand(); - auto operandType = llvm::cast(operand.getType()); - Type elemType = operandType.getElementType(); - auto resultType = llvm::cast(reshapeOp.getType()); - - if (!resultType.hasStaticShape()) - return failure(); - - // If any of the output dimensions is 0, the tensor has no elements. In that - // case, we can just replace the reshape with an empty op. - if (llvm::is_contained(resultType.getShape(), 0)) { - rewriter.replaceOpWithNewOp( - reshapeOp, resultType.getShape(), elemType); - return success(); - } - - resultType = getTypeConverter()->convertType(resultType); - if (!resultType) - return rewriter.notifyMatchFailure(reshapeOp, "type conversion failed"); - - // Special case where the result is a scalar. - if (resultType.getRank() == 0 && !operandType.hasStaticShape()) { - // This means all dimensions of the operand need to be 1. We add a cast to - // cast the dynamic dimensions to 1. - auto staticType = RankedTensorType::get( - llvm::SmallVector(operandType.getRank(), 1), elemType); - operand = rewriter.create(reshapeOp.getLoc(), staticType, - operand); - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, operand, ArrayRef{}); - return success(); - } - - // Compute the reassociation maps for the linalg operation. This will - // succeed if the reshape can be done with a single expand_shape or - // collapse_shape. - if (std::optional> reassociationMap = - getReassociationIndicesForReshape(operandType, resultType)) { - if (resultType.getRank() < operandType.getRank()) { - // We have found a working reassociation map. If the operand is dynamic, - // we first need to cast all unknown dimensions in the input that get - // collapsed to a static-sized dimension in the output, to 1. - SmallVector shape(operandType.getShape().begin(), - operandType.getShape().end()); - for (auto [idx, dims] : llvm::enumerate(*reassociationMap)) { - // If the result dim is dynamic, we do not mind dynamic entries in the - // source. 
- if (resultType.isDynamicDim(idx)) - continue; - for (auto targetDim : dims) { - if (ShapedType::isDynamic(shape[targetDim])) - shape[targetDim] = 1; - } - } - // Insert a cast if types are not the same (ignoring sparse encoding). - auto enc = sparse_tensor::getSparseTensorEncoding(operandType); - auto newOperandType = RankedTensorType::get(shape, elemType, enc); - if (newOperandType != operandType) { - operand = rewriter.create(reshapeOp.getLoc(), - newOperandType, operand); - } - // Generate collapse operation. - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, operand, *reassociationMap); - } else { - // Generate expand operation. - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, operand, *reassociationMap); - } - return success(); - } - - Value collapsedOp = operand; - Location loc = reshapeOp.getLoc(); - auto getIdentityExprs = [&rewriter](int64_t n) { - SmallVector exprs; - for (int i = 0; i < n; ++i) - exprs.push_back(rewriter.getAffineDimExpr(i)); - return exprs; - }; - // Otherwise, we need to first reduce all source dimensions into one and - // then expand to the destination dimensions. If there is only a single - // source dimension, the reduce step can be skipped. TensorCollapseShape - // expects a different rank of operand and result. - if (operandType.getRank() != 1) { - SmallVector collapsingMap = { - // Use operand_type here because we need to collapse all operands - // dimensions. - getIdentityExprs(operandType.getRank())}; - - collapsedOp = - rewriter.create(loc, operand, collapsingMap); - } - // Cast to a known static type if the input has dynamic dimensions. - int64_t totalElems = resultType.getNumElements(); - auto collapsedType = RankedTensorType::get({totalElems}, elemType); - collapsedOp = - rewriter.create(loc, collapsedType, collapsedOp); - if (resultType.getRank() == 1) { - rewriter.replaceOp(reshapeOp, collapsedOp); - } else { - SmallVector expandingMap = { - // Use resultType here because we need to expand to all result - // dimensions. - getIdentityExprs(resultType.getRank())}; - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, collapsedOp, expandingMap); - } - return success(); - } -}; - -template -struct IotaConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using Adaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy iotaOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ShapedType resultShapedType = getHloOpResultType(iotaOp); - if (!resultShapedType) - return failure(); - - resultShapedType = - this->getTypeConverter()->template convertType( - resultShapedType); - if (!resultShapedType) - return rewriter.notifyMatchFailure(iotaOp, "type conversion failed"); - - Type resultElementType = resultShapedType.getElementType(); - - // Construct the indexing maps needed for linalg.generic ops. 
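- // For instance, an iota of type tensor<2x3xi32> with iota_dimension = 1 - // becomes a linalg.generic over two parallel loops whose body yields an - // index_cast of linalg.index 1.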
- unsigned nloops = resultShapedType.getRank(); - - Location loc = iotaOp.getLoc(); - auto linalgOp = rewriter.create( - loc, - /*resultTensorTypes=*/ - ArrayRef{resultShapedType}, - /*inputs=*/ValueRange{}, - /*outputBuffers=*/ - - ValueRange{getEmptyTensorFor(rewriter, loc, resultShapedType, iotaOp, - adaptor.getOperands())}, - llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange /*args*/) { - Value indexOp = nestedBuilder.create( - nestedLoc, iotaOp.getIotaDimension()); - Type unwrappedResultElementType = resultElementType; - if (auto complexType = - llvm::dyn_cast(unwrappedResultElementType)) - unwrappedResultElementType = complexType.getElementType(); - Value castOp = nestedBuilder.create( - nestedLoc, - nestedBuilder.getIntegerType( - unwrappedResultElementType.getIntOrFloatBitWidth()), - indexOp); - castOp = mlir::stablehlo::StableHloOpToStdScalarOp::mapOpOfType< - mlir::stablehlo::ConvertOp>(nestedLoc, resultElementType, - castOp.getType(), {castOp}, - &nestedBuilder); - nestedBuilder.create(nestedLoc, castOp); - }, - linalg::getPrunedAttributeList(iotaOp)); - rewriter.replaceOp(iotaOp, linalgOp.getResultTensors()); - return success(); - } -}; - -template -struct IotaToMapConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using Adaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy iotaOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ShapedType resultTy = getHloOpResultType(iotaOp); - if (!resultTy) - return failure(); - - resultTy = - this->getTypeConverter()->template convertType(resultTy); - if (!resultTy) - return rewriter.notifyMatchFailure(iotaOp, "type conversion failed"); - - Location loc = iotaOp.getLoc(); - Value empty = getEmptyTensorFor(rewriter, loc, resultTy, iotaOp, - adaptor.getOperands()); - - auto linalgOp = rewriter.create( - loc, ValueRange{}, empty, - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange /*args*/) { - Value index = nestedBuilder.create( - nestedLoc, iotaOp.getIotaDimension()); - index = nestedBuilder.create( - nestedLoc, nestedBuilder.getI64Type(), index); - Value result = mlir::stablehlo::StableHloOpToStdScalarOp::mapOpOfType< - mlir::stablehlo::ConvertOp>(nestedLoc, resultTy.getElementType(), - index.getType(), {ValueRange{index}}, - &nestedBuilder); - nestedBuilder.create(nestedLoc, ValueRange{result}); - }, - linalg::getPrunedAttributeList(iotaOp)); - rewriter.replaceOp(iotaOp, linalgOp.getResult()); - return success(); - } -}; - -/// Converts stablehlo.concatenate operation to a linalg.generic op. -struct ConcatenateConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConcatenateOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Shortcut the one-operand case, simplifies code below. - if (adaptor.getOperands().size() == 1) { - rewriter.replaceOp(op, adaptor.getOperands()[0]); - return success(); - } - - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - uint64_t dim = op.getDimension(); - Location loc = op.getLoc(); - Value zero = rewriter.create(loc, 0); - - // Allocate the output tensor with tensor.empty. 
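- // (E.g. concatenating tensor<2xf32> and tensor<3xf32> along dimension 0 - // allocates a tensor<5xf32> here.)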
- Value result = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - - // Generate a generic op to gather the elements of the concatenate. This is - // awkward standalone but allows fusion with other generic ops. - int64_t nloops = resultType.getRank(); - rewriter.replaceOpWithNewOp( - op, - /*resultTensorTypes=*/resultType, - /*inputs=*/ValueRange{}, /*outputBuffers=*/result, - llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location loc, ValueRange) { - OpBuilder b = nestedBuilder; - Value concatDimSize = zero; - Value result; - - SmallVector extractIndices; - extractIndices.reserve(nloops); - for (int64_t i = 0; i < nloops; i++) { - extractIndices.push_back(b.create(loc, i)); - } - - Value indexOp = b.create(loc, dim); - for (auto [idx, arg] : llvm::enumerate(adaptor.getOperands())) { - Value newConcatDimSize; - scf::IfOp ifOp; - if (idx + 1 != adaptor.getOperands().size()) { - // Calculate how far along we have iterated along the concatenate - // dimension. That way we can tell which input to select. - newConcatDimSize = b.create( - loc, concatDimSize, b.create(loc, arg, dim)); - Value cmp = b.create(loc, rewriter.getI1Type(), - arith::CmpIPredicate::ult, - indexOp, newConcatDimSize); - ifOp = b.create(loc, resultType.getElementType(), cmp, - true); - if (result) { - b.create(loc, ifOp->getResults()[0]); - } else { - result = ifOp->getResults()[0]; - } - - b = ifOp.getThenBodyBuilder(b.getListener()); - } - - // Now adjust the index for the concatenated dimension to fit into - // the selected tensor and do an extract at that position. - extractIndices[dim] = - b.create(loc, indexOp, concatDimSize); - Value extract = - b.create(loc, arg, extractIndices); - b.create(loc, extract); - - if (ifOp) { - b = ifOp.getElseBodyBuilder(b.getListener()); - concatDimSize = newConcatDimSize; - } - } - nestedBuilder.create(loc, result); - }, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct ConstConverterTensor final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConstantOp constOp, OpAdaptor /*adaptor*/, - ConversionPatternRewriter &rewriter) const override { - auto replacementType = - getTypeConverter()->convertType(constOp.getType()); - if (!replacementType) - return rewriter.notifyMatchFailure(constOp, "type conversion failed"); - - ElementsAttr replacementAttr = constOp.getValue(); - if (replacementType == constOp.getType()) { - rewriter.replaceOpWithNewOp(constOp, replacementType, - replacementAttr); - return success(); - } else { - auto denseAttr = dyn_cast(constOp.getValue()); - if (!denseAttr) { - return rewriter.notifyMatchFailure( - constOp, - "DenseElementsAttr cast failed (only DenseElementsAttr supported)"); - } - // Signedness conversion. - // TODO(#15442): Add generic mapping utility, so we aren't limited to - // supporting only DenseElementsAttr. - replacementAttr = denseAttr.mapValues(replacementType.getElementType(), - [](const APInt &i) { return i; }); - rewriter.replaceOpWithNewOp(constOp, replacementType, - replacementAttr); - return success(); - } - } -}; - -// TODO(b/156787842): Support the lowering for dynamic shapes. 
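-// For example, reversing tensor<4x5xf32> along dimension 1 uses the operand -// indexing map (d0, d1) -> (d0, 4 - d1).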
-struct ReverseConverter final - : DataMovementOpConverter { - using DataMovementOpConverter::DataMovementOpConverter; - - static SmallVector - getIndexingMaps(mlir::stablehlo::ReverseOp op, Builder *b) { - auto resultType = llvm::cast(getHloOpResultType(op)); - int64_t nloops = resultType.getRank(); - SmallVector inputExprs; - inputExprs.reserve(nloops); - for (int64_t i = 0; i < nloops; ++i) - inputExprs.push_back(b->getAffineDimExpr(i)); - for (int i : op.getDimensions()) { - if (resultType.isDynamicDim(i)) - return {}; - int n = resultType.getShape()[i]; - inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; - } - return { - AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), - b->getMultiDimIdentityMap(nloops)}; - } -}; - -struct SliceConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::SliceOp sliceOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto argType = - llvm::dyn_cast(adaptor.getOperands()[0].getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args"); - } - - SmallVector offsets, sizes, strides; - auto startIndices = sliceOp.getStartIndices(); - auto limitIndices = sliceOp.getLimitIndices(); - auto sliceStrides = sliceOp.getStrides(); - - for (int64_t i = 0, e = argType.getRank(); i < e; ++i) { - int64_t start = startIndices[i]; - int64_t limit = limitIndices[i]; - int64_t stride = sliceStrides[i]; - offsets.push_back(rewriter.getI64IntegerAttr(start)); - // Say that there are k elements in total, we have condition: - // start + (k - 1) * strides <= limit - 1 - // -> - // k <= (limit - 1 - start + strides) / strides - sizes.push_back( - rewriter.getI64IntegerAttr((limit - 1 - start + stride) / stride)); - strides.push_back(rewriter.getI64IntegerAttr(stride)); - } - rewriter.replaceOpWithNewOp( - sliceOp, adaptor.getOperands()[0], offsets, sizes, strides); - return success(); - } -}; - -struct DynamicSliceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicSliceOp dynamicSliceOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = dynamicSliceOp.getLoc(); - auto argType = llvm::dyn_cast(adaptor.getOperand().getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(dynamicSliceOp, - "require known-rank args"); - } - - auto resultType = getTypeConverter()->convertType( - dynamicSliceOp.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(dynamicSliceOp, - "type conversion failed"); - - SmallVector startIndices, sizes; - auto originalStartIndexType = llvm::cast( - dynamicSliceOp.getStartIndices().front().getType()); - for (auto [idx, start, size] : llvm::enumerate( - adaptor.getStartIndices(), dynamicSliceOp.getSliceSizes())) { - sizes.push_back(rewriter.getI64IntegerAttr(size)); - - // By stablehlo.DynamicSlice definition: - // `start_indices[i] = clamp(start_indices[i], - // 0, operand.dimension_size[i] - size_indices[i])` - Value startIndex = - extractIndexFromTensor(rewriter, loc, start, originalStartIndexType); - - Value mn = rewriter.create(loc, 0); - - Value mx = - rewriter.createOrFold(loc, adaptor.getOperand(), idx); - mx = rewriter.createOrFold( - loc, mx, rewriter.create(loc, size)); - - startIndex = rewriter.create(loc, startIndex, mn); - startIndex = 
rewriter.create(loc, startIndex, mx); - - startIndices.push_back(startIndex); - } - - int64_t rank = argType.getRank(); - SmallVector strides(rank, rewriter.getI64IntegerAttr(1)); - - rewriter.replaceOpWithNewOp( - dynamicSliceOp, resultType, adaptor.getOperand(), startIndices, sizes, - strides); - return success(); - } -}; - -struct DynamicUpdateSliceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto operandType = - llvm::dyn_cast(adaptor.getOperand().getType()); - if (!operandType || !operandType.hasStaticShape()) { - return rewriter.notifyMatchFailure( - op, "require static ranked type for operand"); - } - - auto updateType = - llvm::dyn_cast(adaptor.getUpdate().getType()); - if (!updateType || !updateType.hasStaticShape()) { - return rewriter.notifyMatchFailure( - op, "require static ranked type for update"); - } - - // We do not have to clamp sizes because the semantics of `update` - // guarantee that it is always in bounds. See - // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice - SmallVector sizes; - for (int64_t size : updateType.getShape()) { - sizes.push_back(rewriter.getIndexAttr(size)); - } - - SmallVector startIndices; - Value zero = rewriter.create(loc, 0); - for (auto [idx, start] : llvm::enumerate(adaptor.getStartIndices())) { - // By stablehlo.DynamicUpdateSlice definition: - // `start_indices[i] = clamp(start_indices[i], - // 0, operand.dimension_size[i] - update.dimension_size[i])` - Value startIndex = extractIndexFromTensor( - rewriter, loc, start, - cast(op.getStartIndices()[idx].getType())); - Value ub = rewriter.create( - loc, operandType.getDimSize(idx) - updateType.getDimSize(idx)); - - startIndex = rewriter.create(loc, startIndex, zero); - startIndex = rewriter.create(loc, startIndex, ub); - startIndices.push_back(startIndex); - } - - int64_t rank = operandType.getRank(); - SmallVector strides(rank, rewriter.getI64IntegerAttr(1)); - rewriter.replaceOpWithNewOp( - op, adaptor.getUpdate(), adaptor.getOperand(), startIndices, sizes, - strides); - return success(); - } -}; - -struct MapOpToGenericConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::MapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - assert(op.getDimensions().size() == resultType.getRank() && - "Expected a pointwise map"); - - Location loc = op.getLoc(); - Value output = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - SmallVector indexingMaps( - op.getNumOperands() + 1, - rewriter.getMultiDimIdentityMap(resultType.getRank())); - - auto linalgOp = rewriter.create( - loc, resultType, adaptor.getOperands(), output, indexingMaps, - getNParallelLoopsAttrs(resultType.getRank()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. We scalarize the operands and add a - // scalar operand representing the output tensor.
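- // E.g. for a stablehlo.map over two tensor<4xf32> operands the converted - // region signature is (f32, f32, f32), with the trailing f32 standing for - // the output element.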
- Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getComputation(), region, region.end()); - TypeConverter::SignatureConversion signatureConverter(op.getNumOperands() + - 1); - - for (auto [idx, operand] : llvm::enumerate(op.getOperands())) { - Type convertedTy = getTypeConverter()->convertType( - cast(operand.getType()).getElementType()); - if (!convertedTy) - return rewriter.notifyMatchFailure(op, - "operand type conversion failed"); - - signatureConverter.addInputs(idx, convertedTy); - } - signatureConverter.addInputs(resultType.getElementType()); - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct MapOpToMapConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::MapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) - return failure(); - - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - assert(op.getDimensions().size() == resultType.getRank() && - "Expected a pointwise map"); - - Location loc = op.getLoc(); - Value operand0 = adaptor.getOperands()[0]; - SmallVector coercedOperands = {operand0}; - for (Value operand : llvm::drop_begin(adaptor.getOperands(), 1)) { - coercedOperands.push_back(coerceTensorShape( - rewriter, loc, cast>(operand), - cast(operand0.getType()))); - } - Value output = rewriter.create( - loc, tensor::getMixedSizes(rewriter, loc, operand0), - resultType.getElementType()); - - auto linalgOp = rewriter.create( - loc, coercedOperands, output, - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. We scalarize the operands and add a - // scalar operand representing the output tensor. - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getComputation(), region, region.end()); - TypeConverter::SignatureConversion signatureConverter(op.getNumOperands()); - - for (auto [idx, operand] : llvm::enumerate(op.getOperands())) { - Type convertedTy = getTypeConverter()->convertType( - cast(operand.getType()).getElementType()); - if (!convertedTy) - return rewriter.notifyMatchFailure(op, - "operand type conversion failed"); - signatureConverter.addInputs(idx, convertedTy); - } - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - auto result = rewriter.createOrFold(loc, resultType, - linalgOp.getResults()); - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// This lowering encompasses the full range of the Gather operation and -/// therefore is very general and just loops over the output and calculate the -/// corresponding input index. It follows the explanation at -/// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler -/// should be able to optimize that a bit, but in order to get efficient -/// lowerings, special-cases of gather should be extracted in separate -/// lowerings, and ideally encapsulated as separate ops or canonicalization -/// patterns. 
-struct GatherConversion final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::GatherOp gatherOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = gatherOp.getLoc(); - - Value startIndices = adaptor.getStartIndices(); - Value operand = adaptor.getOperand(); - - auto resultType = - getTypeConverter()->convertType(gatherOp.getType()); - RankedTensorType startIndicesType = - dyn_cast(startIndices.getType()); - // We could actually deal with an unranked result by inferring the result - // rank, but the current reifyReturnTypes doesn't support unranked either. - if (!resultType || !startIndicesType) { - return rewriter.notifyMatchFailure(gatherOp, - "unranked start indices or result"); - } - - int64_t resultRank = resultType.getRank(); - // slice_sizes has to have the same size as operand.rank, and doing it this - // way permits an unranked operand. - int64_t operandRank = gatherOp.getSliceSizes().size(); - - int64_t indexVectorDim = gatherOp.getDimensionNumbers().getIndexVectorDim(); - - ArrayRef offsetDims = - gatherOp.getDimensionNumbers().getOffsetDims(); - ArrayRef collapsedSliceDims = - gatherOp.getDimensionNumbers().getCollapsedSliceDims(); - ArrayRef startIndexMap = - gatherOp.getDimensionNumbers().getStartIndexMap(); - - // We'll need these later and creating them on demand we end up with - // duplicates, which also makes lit tests really hard to write. - SmallVector constants; - for (int64_t i = 0, e = std::max({resultRank, operandRank, int64_t{2}}); - i < e; ++i) { - constants.push_back( - rewriter.create(loc, rewriter.getIndexAttr(i))); - } - - Value emptyOp = getEmptyTensorFor(rewriter, loc, resultType, gatherOp, - adaptor.getOperands()); - - ValueRange ins; - SmallVector indexingMaps( - {rewriter.getMultiDimIdentityMap(resultRank)}); - auto linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/resultType, - /*inputs=*/ins, - /*outputs=*/emptyOp, indexingMaps, getNParallelLoopsAttrs(resultRank), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(gatherOp)); - - // Now populate the linalg generic region - Region ®ion = linalgOp.getRegion(); - Block *block = rewriter.createBlock(®ion, region.end()); - block->addArguments(resultType.getElementType(), loc); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(block); - - // Dimensions in the result that aren't offset dimensions are called batch. - SmallVector batchDims; - for (int64_t dim = 0; dim < resultRank; ++dim) { - if (!llvm::is_contained(offsetDims, dim)) { - batchDims.push_back(dim); - } - } - - // Same as with the constants. Creating these all up front is easier than - // potentially getting duplicates later. - SmallVector linalgIndices; - for (int64_t i = 0; i < resultRank; ++i) { - linalgIndices.push_back(rewriter.create(loc, i)); - } - - // Now the complicated part. For a given output dimension we build up an - // index into the input. It's composed of two parts: the index coming from - // start_indices, and the offset from that index along the offset - // dimensions. Everything includes dimension shuffling and remapping as well - // because of the way gather is defined to allow for any-layout input by - // adding more attributes. - - // The base gather index (`G` in the documentation) points to a place in - // start_indices along the batch dimensions. 
- SmallVector gatherIndex; - for (int64_t dim : batchDims) { - gatherIndex.push_back(linalgIndices[dim]); - } - - SmallVector indexFromStartIndices; - for (size_t i = 0, e = startIndexMap.size(); i != e; ++i) { - // The index along the index_vector dimension of start_indices varies. - // Basically indexFromStartIndices indexes into a "row" along - // index_vector_dim, where the row is selected by the current output - // index. - // But if index_vector_dim is equal to start_indices.rank, then - // start_indices gets a trailing 1 dimension added. So the row we're - // extracting always has length 1 and the index into it is always 0, so we - // just use the gather index directly. - SmallVector gCombine(gatherIndex); - if (indexVectorDim != startIndicesType.getRank()) { - assert(indexVectorDim <= static_cast(gCombine.size())); - gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]); - } - - indexFromStartIndices.push_back(extractIndexFromTensor( - rewriter, loc, startIndices, gatherOp.getStartIndices().getType(), - gCombine)); - } - - // But then start indices are shuffled by the start index map. To make a - // full index into the operand, all missing indices are zeroes. - SmallVector remappedIndexFromIndices(operandRank, constants[0]); - for (auto [idx, value] : llvm::enumerate(startIndexMap)) { - remappedIndexFromIndices[value] = indexFromStartIndices[idx]; - } - - // Now we construct the index based on the offset. First we need to remap - // the offset dimensions by dropping the collapsed indices. - SmallVector remappedOffsetDims; - for (int64_t i = 0; i < operandRank; ++i) { - if (!llvm::is_contained(collapsedSliceDims, i)) { - remappedOffsetDims.push_back(static_cast(i)); - } - } - - assert(remappedOffsetDims.size() == offsetDims.size()); - - // Clamp out-of-bounds indices. - for (int i = 0, operandIndexDim = 0; i < operandRank; ++i) { - // Compute the size of the output shape dimension corresponding to this - // index dimension. If it's collapsed, set it to 1. - Value outputDimSize = constants[1]; - if (!llvm::is_contained(collapsedSliceDims, i)) { - outputDimSize = rewriter.createOrFold( - loc, emptyOp, offsetDims[operandIndexDim++]); - } - - // If this is a skipped dimension, we're done and don't have to clamp. - if (remappedIndexFromIndices[i] == constants[0]) - continue; - - Value operandDimSize = - rewriter.createOrFold(loc, operand, i); - Value largestValidIndex = rewriter.createOrFold( - loc, operandDimSize, outputDimSize); - - // Clamp indices to [0, operand_dim - output_dim]. - Value clamp = rewriter.create( - loc, - rewriter.create(loc, constants[0], - remappedIndexFromIndices[i]), - largestValidIndex); - remappedIndexFromIndices[i] = clamp; - } - - // For the (remapped) offset dimensions, the index is the current index in - // the output. As before this is expanded to a full index into the operand - // by using zeros for the missing indices. - SmallVector indexFromOffset(operandRank, constants[0]); - for (auto [remappedOffsetDim, offsetDim] : - llvm::zip_equal(remappedOffsetDims, offsetDims)) { - indexFromOffset[remappedOffsetDim] = linalgIndices[offsetDim]; - } - - // Now we add together our two indices to get the final index into the - // operand.
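- // E.g. with a (clamped) start index of (2, 0) and output offset coordinates - // (i, j), the element extracted below is operand[2 + i][0 + j].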
- SmallVector combinedIndex; - for (int64_t i = 0; i < operandRank; ++i) - combinedIndex.push_back(rewriter.createOrFold( - loc, rewriter.getIndexType(), remappedIndexFromIndices[i], - indexFromOffset[i])); - - Value extractOperand; - if (isa(operand.getType())) { - extractOperand = operand; - } else { - // Cannot extract from unranked tensors; cast to ranked first. - SmallVector dims(operandRank, ShapedType::kDynamic); - auto type = RankedTensorType::get( - dims, cast(operand.getType()).getElementType()); - extractOperand = rewriter.create(loc, type, operand); - } - Value element = - rewriter.create(loc, extractOperand, combinedIndex); - rewriter.create(loc, element); - - rewriter.replaceOp(gatherOp, linalgOp.getResults()); - - return success(); - } -}; - -/// Converts stablehlo.select_and_scatter op to a sequence of linalg.generic ops. -/// The current version computes the scattered index and populates the correct -/// value for each tile. It does not currently handle overlapping tiles. -struct SelectAndScatterNoOverlapConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::SelectAndScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - ImplicitLocOpBuilder b(loc, rewriter); - Value source = op.getSource(); - Value operand = op.getOperand(); - Value init = op.getInitValue(); - - auto sourceTy = llvm::dyn_cast(source.getType()); - auto operandTy = llvm::dyn_cast(operand.getType()); - auto initTy = llvm::dyn_cast(init.getType()); - auto resultTy = llvm::dyn_cast(op.getResult().getType()); - if (!sourceTy || !operandTy || !initTy || !resultTy) - return rewriter.notifyMatchFailure(op, "inputs/outputs must be ranked"); - - auto indexETy = b.getI32Type(); - auto srcETy = operandTy.getElementType(); - auto destETy = initTy.getElementType(); - - const int64_t rank = sourceTy.getRank(); - - llvm::SmallVector pad(rank * 2, 0); - if (op.getPadding().has_value()) - pad = llvm::to_vector(op.getPaddingAttr().getValues()); - - // TODO(suderman): Add support for padding. - if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) - return rewriter.notifyMatchFailure(op, "non-zero padding values found."); - - if (!op.getWindowStrides().has_value()) - return rewriter.notifyMatchFailure(op, "no window strides found"); - - if (!op.getWindowDimensions().has_value()) - return rewriter.notifyMatchFailure(op, "no window dimensions found"); - - auto strides = llvm::to_vector(op.getWindowStrides().value()); - auto window = llvm::to_vector(op.getWindowDimensions().value()); - - if (static_cast(strides.size()) != operandTy.getRank() || - static_cast(window.size()) != operandTy.getRank()) - return rewriter.notifyMatchFailure( - op, "stride/window length should equal operand rank"); - - // The current version cannot handle overlapped regions. - for (int i = 0, s = strides.size(); i < s; ++i) { - if (strides[i] < window[i]) - return rewriter.notifyMatchFailure( - op, "overlapping windows are not supported"); - } - - // If the window only contains a single element, this lowering will be - // problematic. Ultimately we should handle this with a canonicalizer. - if (llvm::all_of(window, [](auto sz) { return sz == 1; })) { - return rewriter.notifyMatchFailure(op, - "unary window size is not supported"); - } - - // The first linalg.generic operation computes the relevant index over - // window for the defined stablehlo.select_and_scatter.
This involves - // iterating over the window of the operand and computing the index. - // Rather than storing N indices, we compute the row-major identifier - // in the window, to specify which location should be scattered to. - - // Output specifies the `rank` parallel iterators that correspond to - // output values. - SmallVector outputExprs; - for (int i = 0, s = rank; i < s; ++i) - outputExprs.push_back(b.getAffineDimExpr(i)); - - // For the output we need to define the reduction across the window - // width and height. This includes applying striding behavior and - // adding the additional reduction iterators. We skip length-1 dimensions - // as the reduction is degenerate. - SmallVector filteredWindows, filteredStrides; - SmallVector sourceExprs(outputExprs); - SmallVector windowExprs; - for (int i = 0, s = rank; i < s; ++i) { - sourceExprs[i] = sourceExprs[i] * strides[i]; - if (strides[i] != 1) { - auto expr = b.getAffineDimExpr(windowExprs.size() + sourceExprs.size()); - sourceExprs[i] = sourceExprs[i] + expr; - windowExprs.push_back(expr); - filteredWindows.push_back(window[i]); - filteredStrides.push_back(strides[i]); - } - } - - // Determine the total number of AffineExprs and construct the IndexingMaps - // for the windowed reduction operation. - const int64_t reduceExprCount = windowExprs.size() + sourceExprs.size(); - SmallVector reduceIndexingMaps; - reduceIndexingMaps.push_back(AffineMap::get(reduceExprCount, - /*symbolCount=*/0, sourceExprs, - rewriter.getContext())); - reduceIndexingMaps.push_back(AffineMap::get(reduceExprCount, - /*symbolCount=*/0, windowExprs, - rewriter.getContext())); - auto reduceOutMap = - AffineMap::get(reduceExprCount, - /*symbolCount=*/0, outputExprs, rewriter.getContext()); - reduceIndexingMaps.push_back(reduceOutMap); - reduceIndexingMaps.push_back(reduceOutMap); - - // Output sizes should match the dimensions of the `source` tensor, even if - // dynamic. - SmallVector reduceDynSizes; - for (int i = 0, s = rank; i < s; ++i) - if (sourceTy.isDynamicDim(i)) - reduceDynSizes.push_back(b.create(source, i)); - - Value reduceValueEmpty = - b.create(sourceTy.getShape(), destETy, reduceDynSizes); - Value reduceIndexEmpty = b.create( - sourceTy.getShape(), indexETy, reduceDynSizes); - - // We initialize indices to -1, which indicates no matching destination. - Value negativeOne = b.create(b.getI32IntegerAttr(-1)); - reduceIndexEmpty = - b.create(negativeOne, reduceIndexEmpty).getResult(0); - - // We only care to match the reduction dimensions. - Value windowEmpty = b.create(filteredWindows, srcETy); - - auto reduceGeneric = b.create( - /*resultTensors=*/ArrayRef{reduceValueEmpty.getType(), - reduceIndexEmpty.getType()}, - /*inputs=*/ValueRange{operand, windowEmpty}, - /*outputs=*/ValueRange{reduceValueEmpty, reduceIndexEmpty}, - reduceIndexingMaps, - getParallelAndReductionIterators(reduceExprCount, windowExprs.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // First we clone in the selection block. - auto &reduceRegion = reduceGeneric.getRegion(); - rewriter.setInsertionPoint(reduceGeneric); - rewriter.cloneRegionBefore(op.getSelect(), reduceRegion, - reduceRegion.end()); - - // This includes converting `stablehlo` scalar-tensor regions to `linalg` - // scalars.
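- // The cloned select block's two tensor arguments become four scalars: the - // operand value, the window placeholder, the running best value, and the - // running best index.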
- TypeConverter::SignatureConversion reduceSignConverter(4); - reduceSignConverter.addInputs(0, srcETy); - reduceSignConverter.addInputs(srcETy); - reduceSignConverter.addInputs(1, destETy); - reduceSignConverter.addInputs(indexETy); - rewriter.applySignatureConversion(&reduceRegion.front(), - reduceSignConverter, getTypeConverter()); - - // Grab the terminator and use the returned value to now select the - // correct index and value. - auto &reduceBlock = reduceRegion.front(); - auto *reduceTerminator = reduceBlock.getTerminator(); - Value selectPred = reduceTerminator->getOperand(0); - Value selectInVal = reduceBlock.getArgument(0); - Value selectOutVal = reduceBlock.getArgument(2); - Value selectOutIdx = reduceBlock.getArgument(3); - - b.setInsertionPoint(reduceTerminator); - - // The predicate operates on scalar-tensors, so we need to extract the - // value for `linalg` operations. Tensor-ops are cleaned up by other - // rewriters. - selectPred = b.create(rewriter.getI1Type(), selectPred, - ValueRange{}); - - // We select if either the selection function returns `true` or the - // current reduction index is `-1`, i.e. no index has been selected yet. - Value selectNegOne = b.create(arith::CmpIPredicate::eq, - selectOutIdx, negativeOne); - selectPred = b.create(selectPred, selectNegOne); - - // We compute a unique idx for each element in the window. - Value computedIdx = b.create(rank); - for (int i = 1, s = filteredStrides.size(); i < s; ++i) { - Value width = b.create(filteredStrides[i]); - Value idx = b.create(rank + i); - computedIdx = b.create(width, computedIdx); - computedIdx = b.create(computedIdx, idx); - } - computedIdx = b.create(indexETy, computedIdx); - - // Using the selection predicate, track the value and selected - // identifier for the future scattering. - Value selectedIdx = - b.create(selectPred, computedIdx, selectOutIdx); - Value selectedValue = - b.create(selectPred, selectInVal, selectOutVal); - b.create(ValueRange{selectedValue, selectedIdx}); - - // The original terminator is a stablehlo.return we no longer need. - rewriter.eraseOp(reduceTerminator); - b.setInsertionPoint(op); - - Value reduceIndex = reduceGeneric.getResult(1); - ShapedType reduceIndexTy = llvm::cast(reduceIndex.getType()); - - // For the second generic we are restricted to cases where there are - // no window overlaps. This guarantees that each source value is scattered - // within its own unique window. We can broadcast to this window size and - // populate only the relative location. - llvm::SmallVector broadcastShape; - llvm::SmallVector broadcastDynDims; - llvm::SmallVector broadcastExprs; - for (int i = 0, s = reduceIndexTy.getRank(); i < s; ++i) { - int64_t broadcast = strides[i]; - if (sourceTy.isDynamicDim(i)) - broadcastDynDims.push_back(b.create(source, i)); - - broadcastExprs.push_back(b.getAffineDimExpr(broadcastShape.size())); - broadcastShape.push_back(sourceTy.getDimSize(i)); - if (broadcast > 1) { - broadcastShape.push_back(broadcast); - } - } - - // We broadcast the values of our input tensors across the stride-tiling - // size. - Value scatterEmpty = b.create( - broadcastShape, resultTy.getElementType(), broadcastDynDims); - Value initScalar = b.create(initTy.getElementType(), - init, ValueRange{}); - Value scatterFill = - b.create(initScalar, scatterEmpty).getResult(0); - - // Both the indices and values are broadcasted using the same indexing map. - // Output fully parallel.
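- // E.g. a rank-1 source tensor<4xf32> with stride 2 broadcasts into the - // tensor<4x2xf32> tile space through the input map (d0, d1) -> (d0).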
- auto scatterInputMap = - AffineMap::get(broadcastShape.size(), /*symbolCount=*/0, broadcastExprs, - b.getContext()); - SmallVector scatterIndexingMaps = { - scatterInputMap, scatterInputMap, - b.getMultiDimIdentityMap(broadcastShape.size())}; - - auto scatterGeneric = b.create( - /*resultTensors=*/ArrayRef{scatterFill.getType()}, - /*inputs=*/ValueRange{reduceIndex, source}, - /*outputs=*/ValueRange{scatterFill}, scatterIndexingMaps, - getNParallelLoopsAttrs(broadcastShape.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Clone the scattering combination logic and perform the tensor-to-scalar - // conversion. - auto &scatterRegion = scatterGeneric.getRegion(); - b.setInsertionPoint(scatterGeneric); - rewriter.cloneRegionBefore(op.getScatter(), scatterRegion, - scatterRegion.end()); - - TypeConverter::SignatureConversion scatterSignConverter(4); - scatterSignConverter.addInputs(indexETy); - scatterSignConverter.addInputs(0, sourceTy.getElementType()); - scatterSignConverter.addInputs(1, sourceTy.getElementType()); - rewriter.applySignatureConversion(&scatterRegion.front(), - scatterSignConverter, getTypeConverter()); - - auto &scatterBlock = scatterRegion.front(); - auto scatterTerminator = scatterBlock.getTerminator(); - b.setInsertionPoint(scatterTerminator); - - Value scatterInputIdx = scatterBlock.getArgument(0); - Value scatterOutputVal = scatterBlock.getArgument(2); - Value scatterUpdate = b.create( - sourceTy.getElementType(), scatterTerminator->getOperand(0), - ValueRange{}); - - // Compute the index of the tiled region to determine if it was selected. - Value id = b.create(0); - int64_t dim = 0; - for (int i = 0, s = strides.size(); i < s; ++i) { - if (strides[i] > 1) { - Value idx = b.create(++dim); - Value tileSz = b.create(strides[i]); - id = b.create(id, tileSz); - id = b.create(id, idx); - } - ++dim; - } - - // Check whether the computed id matches the to-scatter id, then select and - // yield. - id = b.create(indexETy, id); - auto scatterPred = b.create( - b.getI1Type(), arith::CmpIPredicate::eq, id, scatterInputIdx); - scatterUpdate = - b.create(scatterPred, scatterUpdate, scatterOutputVal); - - b.create(scatterUpdate); - rewriter.eraseOp(scatterTerminator); - b.setInsertionPoint(op); - - // We now need to collapse the tiles back into their - // source dimensions. We collapse any of the broadcast regions together. - int64_t collapseDim = 0; - SmallVector reassociationMap; - for (int i = 0, s = window.size(); i < s; ++i) { - SmallVector dims = {collapseDim}; - if (strides[i] > 1) - dims.push_back(collapseDim + 1); - - reassociationMap.push_back(ReassociationIndices(dims)); - collapseDim += dims.size(); - } - - Value collapse = b.create( - scatterGeneric.getResult(0), reassociationMap); - auto collapseTy = llvm::cast(collapse.getType()); - - // After collapsing, it is possible that the target may need to be padded.
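- // E.g. a collapsed tensor<4xf32> scattered into an operand of type - // tensor<5xf32> needs one element of high padding below.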
- auto zero = b.createOrFold(0); - SmallVector padShape; - SmallVector padLow, padHigh; - padLow.resize(operandTy.getRank(), zero); - - for (int i = 0, s = rank; i < s; ++i) { - int64_t size = std::max(resultTy.getDimSize(i), collapseTy.getDimSize(i)); - if (operandTy.isDynamicDim(i) || collapseTy.isDynamicDim(i)) - size = ShapedType::kDynamic; - padShape.push_back(size); - - Value in = b.create(collapse, i); - Value out = b.create(operand, i); - Value diff = b.create(out, in); - Value pad = b.createOrFold(diff, zero); - padHigh.push_back(pad); - } - - Value padded = b.create(collapseTy.clone(padShape), collapse, - padLow, padHigh, initScalar); - - // The result may exceed the target size, slice if necessary. - SmallVector sliceSizes; - SmallVector sliceOffsets(operandTy.getRank(), - b.getIndexAttr(0)); - SmallVector sliceStrides(operandTy.getRank(), - b.getIndexAttr(1)); - for (int i = 0, s = operandTy.getRank(); i < s; ++i) { - OpFoldResult dim = b.getIndexAttr(operandTy.getDimSize(i)); - if (operandTy.isDynamicDim(i)) - dim = b.createOrFold(operand, i); - sliceSizes.push_back(dim); - } - - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, padded, sliceOffsets, sliceSizes, sliceStrides); - - return success(); - } -}; - -// Decomposes a pad with negative edge padding into a pad without negative edge -// padding and a tensor.extract_slice. -struct PadOpNegativePaddingConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::PadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector padLow; - SmallVector padHigh; - SmallVector sliceStarts; - - bool hasNegativePadding = false; - for (int64_t low : op.getEdgePaddingLow()) { - if (low >= 0) { - padLow.push_back(low); - sliceStarts.push_back(rewriter.getIndexAttr(0)); - } else { - padLow.push_back(0); - sliceStarts.push_back(rewriter.getIndexAttr(-low)); - hasNegativePadding = true; - } - } - - for (int64_t high : op.getEdgePaddingHigh()) { - if (high >= 0) { - padHigh.push_back(high); - } else { - padHigh.push_back(-high); - hasNegativePadding = true; - } - } - - // If there's no negative edge padding we're done. - if (!hasNegativePadding) - return failure(); - - // Create a new pad op with the positive values. - Value pad = rewriter.create( - op.getLoc(), adaptor.getOperand(), adaptor.getPaddingValue(), - rewriter.getDenseI64ArrayAttr(padLow), - rewriter.getDenseI64ArrayAttr(padHigh), op.getInteriorPadding()); - - // Then slice according to the negative edge padding. Static shapes only for - // now. - if (!op.getType().hasStaticShape()) - return failure(); - SmallVector sizes( - llvm::map_range(op.getType().getShape(), [&](int64_t dim) { - return rewriter.getIndexAttr(dim); - })); - SmallVector strides(sliceStarts.size(), - rewriter.getIndexAttr(1)); - rewriter.replaceOpWithNewOp(op, pad, sliceStarts, - sizes, strides); - return success(); - } -}; - -/// Converts stablehlo.pad operation to tensor.pad or tensor.insert_slice. 
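-/// With no interior padding this lowers directly to tensor.pad. With interior -/// padding, e.g. padding tensor<3xf32> with low = 1, high = 2, interior = 1, a -/// tensor<8xf32> is filled with the padding value and the operand is inserted -/// via tensor.insert_slice at offset 1 with stride 2 (interior + 1).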
-struct PadOpConversion final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::PadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto resultType = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - // Negative edge padding is decomposed separately. - auto isNegative = [](int64_t intVal) { return intVal < 0; }; - if (llvm::any_of(op.getEdgePaddingLow(), isNegative) || - llvm::any_of(op.getEdgePaddingHigh(), isNegative)) - return failure(); - - Value paddingVal = rewriter.createOrFold( - loc, adaptor.getPaddingValue()); - - auto i64ToFoldResult = [&](const int64_t &i) -> OpFoldResult { - return rewriter.getIntegerAttr(rewriter.getI64Type(), i); - }; - - // If there is no interior padding lower to tensor.pad directly. - if (llvm::all_of(op.getInteriorPadding(), - [](const int64_t &i) { return i == 0; })) { - auto padTensorOp = rewriter.create( - loc, resultType, adaptor.getOperand(), - llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), - llvm::map_to_vector(op.getEdgePaddingHigh(), i64ToFoldResult), - paddingVal); - rewriter.replaceOp(op, padTensorOp.getResult()); - return success(); - } - - // We have interior padding, which can be lowered to tensor.insert_slice. - // Start by filling a result-sized tensor with the pad value. - auto emptyTensor = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); - auto fill = - rewriter.create(loc, paddingVal, emptyTensor).result(); - - // Get sizes of the original operand. - auto operandType = llvm::cast(adaptor.getOperand().getType()); - auto sizes = llvm::map_to_vector( - llvm::seq(0, operandType.getRank()), - [&](int64_t dim) -> OpFoldResult { - if (!operandType.isDynamicDim(dim)) - return rewriter.getIndexAttr(operandType.getDimSize(dim)); - return rewriter.create(loc, adaptor.getOperand(), dim) - .getResult(); - }); - // Map interior padding to strides. - auto strides = llvm::map_to_vector( - op.getInteriorPadding(), [&](const int64_t &stride) -> OpFoldResult { - return rewriter.getIntegerAttr(rewriter.getI64Type(), stride + 1); - }); - - rewriter.replaceOpWithNewOp( - op, adaptor.getOperand(), fill, - llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), sizes, - strides); - return success(); - } -}; - -/// Converts xla-hlo.torch_index_select op to a linalg.generic op. 
-struct TorchIndexSelectOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::TorchIndexSelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - int axis = static_cast(op.getDim()); - int batch = static_cast(op.getBatchDims()); - auto indexShapedType = llvm::cast(adaptor.getIndex().getType()); - int numIndices = static_cast(indexShapedType.getRank()); - auto operandShapedType = - llvm::cast(adaptor.getOperand().getType()); - if (axis < 0) - axis += static_cast(operandShapedType.getRank()); - if (batch < 0) - batch += numIndices; - - Location loc = op.getLoc(); - auto resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(op, "type conversion failed"); - - int rank = static_cast(resultType.getRank()); - - // The output shape is - // `params[:axis] + indices[batch_dims:] + params[axis + 1:]` - SmallVector dynSizes; - for (int i = 0; i < rank; ++i) { - if (!resultType.isDynamicDim(i)) - continue; - if (i < axis) { - dynSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), i)); - } else if (i < (axis + numIndices - batch)) { - int idx = i - axis + batch; - dynSizes.push_back( - rewriter.create(loc, adaptor.getIndex(), idx)); - } else { - int idx = i - (axis + numIndices - batch) + axis + 1; - dynSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), idx)); - } - } - - // Generate dummy tensor to preserve slice shape information. - SmallVector sliceShape; - SmallVector dynSliceSizes; - SmallVector sliceExprs; - ArrayRef resultShape = resultType.getShape(); - for (int i = 0; i < axis; ++i) { - sliceExprs.push_back(rewriter.getAffineDimExpr(i)); - sliceShape.push_back(resultShape[i]); - if (!resultType.isDynamicDim(i)) - continue; - dynSliceSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), i)); - } - for (int i = axis + numIndices - batch; i < rank; ++i) { - sliceExprs.push_back(rewriter.getAffineDimExpr(i)); - sliceShape.push_back(resultShape[i]); - if (!resultType.isDynamicDim(i)) - continue; - int idx = i - (axis + numIndices - batch) + axis + 1; - dynSliceSizes.push_back( - rewriter.create(loc, adaptor.getOperand(), idx)); - } - - // Setup AffineMap for operand tensor. - SmallVector exprs; - for (int i = 0; i < batch; ++i) { - exprs.push_back(rewriter.getAffineDimExpr(i)); - } - for (int i = 0, e = numIndices - batch; i < e; ++i) { - exprs.push_back(rewriter.getAffineDimExpr(axis + i)); - } - - SmallVector indexingMaps; - indexingMaps.emplace_back( - AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext())); - indexingMaps.emplace_back(AffineMap::get( - rank, /*symbolCount=*/0, sliceExprs, rewriter.getContext())); - indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank)); - - Value sliceOp = rewriter.create( - loc, sliceShape, resultType.getElementType(), dynSliceSizes); - - Value emptyOp = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), dynSizes); - auto linalgOp = rewriter.create( - loc, /*resultTensors=*/ArrayRef{resultType}, - /*inputs=*/ValueRange{adaptor.getIndex(), sliceOp}, - /*outputs=*/emptyOp, indexingMaps, getNParallelLoopsAttrs(rank), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - SmallVector bodyArgTypes; - SmallVector linalgOpArgs = {adaptor.getIndex(), sliceOp}; - // Add a block to the region. 
- auto *region = &linalgOp.getRegion(); - auto *block = rewriter.createBlock(region, region->end()); - for (auto blockArgs : linalgOpArgs) { - bodyArgTypes.push_back( - llvm::cast(blockArgs.getType()).getElementType()); - } - block->addArguments(bodyArgTypes, - SmallVector(bodyArgTypes.size(), loc)); - block->addArguments(resultType.getElementType(), loc); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(block); - - Value castedValue = rewriter.create( - loc, rewriter.getIndexType(), block->getArgument(0)); - - SmallVector indices; - for (int i = 0; i < axis; ++i) { - indices.push_back(rewriter.create(loc, i)); - } - indices.push_back(castedValue); - for (int i = axis + numIndices - batch; i < rank; ++i) { - indices.push_back(rewriter.create(loc, i)); - } - Value res = - rewriter.create(loc, adaptor.getOperand(), indices); - rewriter.create(loc, res); - - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct SetDimensionSizeConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::SetDimensionSizeOp setDimensionSizeOp, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // We can lower SetDimensionSize to tensor extract. This turns into a - // regular dynamic shape. Note that the bounds annotation is still around - // but may be no longer valid depending on choices made by bufferization. - Location loc = setDimensionSizeOp.getLoc(); - auto resultType = dyn_cast(setDimensionSizeOp.getType()); - if (!resultType) - return rewriter.notifyMatchFailure(setDimensionSizeOp, - "expected a ranked tensor"); - - SmallVector offsets(resultType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector strides(resultType.getRank(), - rewriter.getIndexAttr(1)); - SmallVector sizes(llvm::map_range( - resultType.getShape(), [&](int64_t dim) -> OpFoldResult { - return rewriter.getIndexAttr(dim); - })); - Value dimensionSize = - rewriter.create(loc, setDimensionSizeOp.getSize()); - sizes[setDimensionSizeOp.getDimension()] = - rewriter - .create(loc, rewriter.getIndexType(), - dimensionSize) - .getResult(); - - rewriter.replaceOpWithNewOp( - setDimensionSizeOp, resultType, adaptor.getOperand(), offsets, sizes, - strides); - return success(); - } -}; - -struct ConvertStableHloToLinalg final - : impl::ConvertStableHloToLinalgBase { - using ConvertStableHloToLinalgBase::ConvertStableHloToLinalgBase; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - MLIRContext &ctx = getContext(); - RewritePatternSet patterns(&ctx); - ConversionTarget target(ctx); - target.addLegalDialect< - bufferization::BufferizationDialect, arith::ArithDialect, - complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, - tensor::TensorDialect, sparse_tensor::SparseTensorDialect, - scf::SCFDialect, shape::ShapeDialect>(); - - target.addLegalOp(); - - auto typeConverter = createStableHloToLinalgTypeConverter(); - ModuleOp module = getOperation(); - - populateStableHloToLinalgConversionPatterns(&ctx, *typeConverter, &patterns, - this->enablePrimitiveOps); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -void populateStableHloToLinalgConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet *patterns, - bool enablePrimitiveOps) { - // clang-format off - patterns->add< 
- BitcastConvertConverter, - ConcatenateConverter, - ConstConverterTensor, - EinsumToLinalgConverter, - GatherConversion, - RealDynamicSliceConverter, - ReshapeOpConverter, - ReverseConverter, - SetDimensionSizeConverter, - SliceConverter, - DynamicSliceConverter, - DynamicUpdateSliceConverter, - PadOpConversion, - PadOpNegativePaddingConversion, - TorchIndexSelectOpConversion, - SelectAndScatterNoOverlapConverter - >(typeConverter, context); - - detail::populatePointwiseStableHloToLinalgConversionPatterns( - context, typeConverter, patterns, enablePrimitiveOps); - - if (enablePrimitiveOps) { - patterns->add< - BroadcastInDimOpToBroadcastConverter, - BroadcastOpToBroadcastConverter, - DynamicBroadcastInDimOpToBroadcastConverter, - IotaToMapConverter, - IotaToMapConverter, - MapOpToMapConverter, - TransposeOpToTransposeConverter - >(typeConverter, context); - } else { - patterns->add< - BroadcastConverter, - IotaConverter, - IotaConverter, - HloBroadcastInDimConverter, - HloDynamicBroadcastInDimConverter, - MapOpToGenericConverter, - TransposeConverter - >(typeConverter, context); - } - - // clang-format on - - detail::populateStableHloConvolutionToLinalgConversionPatterns( - context, typeConverter, patterns); - detail::populateStableHloDotProdToLinalgConversionPatterns( - context, typeConverter, patterns); - detail::populateStableHloRandomToLinalgConversionPatterns( - context, typeConverter, patterns); - detail::populateStableHloReductionToLinalgConversionPatterns( - context, typeConverter, patterns, enablePrimitiveOps); - detail::populateScalarHloToArithConversionPatterns( - context, typeConverter, patterns, isInBodyOfLinalgOps); - linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns); -} - -std::unique_ptr createStableHloToLinalgTypeConverter() { - return std::make_unique(); -} - -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp deleted file mode 100644 index f51fcd66399a..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgConvolution.cpp +++ /dev/null @@ -1,807 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO convolution ops to Linalg dialect. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -/// Apply dilation and padding to the input of a convolution. 
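// (Illustrative aside: per spatial dimension d with window padding [lo, hi]
// and input dilation factor f, the helper below folds everything into a
// single pad op (stablehlo.pad in the upstream version of this file) as
//   padLow[d] = lo, padHigh[d] = hi, padInterior[d] = f - 1,
// e.g. lhs_dilation = 2 becomes one interior zero between neighbouring
// input elements.)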
-Value applyConvolutionPadding(Location loc, Value input, - DenseIntElementsAttr padding, - std::optional> lhsDilation, - llvm::ArrayRef dimMappings, - OpBuilder &rewriter) { - SmallVector lhsDilationValues; - if (lhsDilation.has_value()) - lhsDilationValues = llvm::to_vector(lhsDilation.value()); - bool noPadding = !padding || isSplatValue(padding, 0); - bool noDilation = !lhsDilation || hlo::isSplatArray(lhsDilationValues, 1); - if (noPadding && noDilation) - return input; - - auto inputType = cast(input.getType()); - int64_t rank = inputType.getRank(); - - // Translate window padding into low/high padding. - SmallVector padLow(rank, 0); - SmallVector padHigh(rank, 0); - if (padding) { - // The padding attribute contains two values per dimension, but excludes the - // batch and feature dimensions. - assert(rank * 2 == padding.size() + 4 && - "There should be 2 padding values per dimension, i.e low and high."); - for (int64_t i : llvm::seq(0, padding.size() / 2)) { - int64_t dim = dimMappings[i]; - padLow[dim] = padding.getValues()[i * 2]; - padHigh[dim] = padding.getValues()[i * 2 + 1]; - } - } - - // Translate input dilation into interior padding. - SmallVector padInterior(rank, 0); - if (lhsDilation) { - assert(rank == static_cast(lhsDilationValues.size()) + 2); - for (int64_t i : llvm::seq(0, lhsDilationValues.size())) { - int64_t dim = dimMappings[i]; - padInterior[dim] = lhsDilationValues[i] - 1; - } - } - - Value zero; - if (auto complexType = dyn_cast(inputType.getElementType())) { - auto zeroElement = rewriter.getZeroAttr(complexType.getElementType()); - auto zeroAttr = rewriter.getArrayAttr({zeroElement, zeroElement}); - zero = rewriter.create(loc, complexType, zeroAttr); - zero = rewriter.create( - loc, RankedTensorType::get({}, complexType), zero); - } else { - zero = rewriter.create( - loc, rewriter.getZeroAttr( - RankedTensorType::get({}, inputType.getElementType()))); - } - - return rewriter.create(loc, input, zero, padLow, - padHigh, padInterior); -} - -/// If the ConvolutionOp has a window reversal, applies it to the filter. -Value applyConvolutionReversal(Location loc, OpBuilder &b, - mlir::stablehlo::ConvolutionOp op, - Value filter) { - std::optional reversals = op.getWindowReversal(); - if (!reversals.has_value()) { - return filter; - } - llvm::SmallVector reversedDims; - for (auto [idx, reversed] : llvm::enumerate(reversals.value())) { - if (reversed) { - reversedDims.push_back( - op.getDimensionNumbers().getKernelSpatialDimensions()[idx]); - } - } - - return b.create( - loc, filter, b.getDenseI64ArrayAttr(reversedDims)); -} - -/// Returns true if the given `dimensionNumbers` from a stablehlo.convolution op -/// follows a canonical form: -/// -/// * Input dimensions have order: (batch_count, spatial_dims, -/// input_channel_count). -/// * Filter dimensions have order: (spatial_dims, input_channel_count, -/// output_channel_count). -/// * Output dimensions have order: (batch_count, spatial_dims, -/// output_channel_count). -bool hasCanonicalDimensionNumbers( - mlir::stablehlo::ConvDimensionNumbersAttr dimensionNumbers) { - const int64_t inputSpatialRank = - dimensionNumbers.getInputSpatialDimensions().size(); - // The dimensions for input should follow the order of - // batch_count, spatial_dims..., input_feature_count. 
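// (Illustrative aside: for a 2-D convolution the canonical form checked
// below is the familiar NHWC/HWCF/NHWC layout, i.e.
//   input  = (batch = 0, spatial = {1, 2}, feature = 3),
//   kernel = (spatial = {0, 1}, input_feature = 2, output_feature = 3),
//   output = (batch = 0, spatial = {1, 2}, feature = 3).)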
- if (dimensionNumbers.getInputBatchDimension() != 0 || - dimensionNumbers.getInputFeatureDimension() != (inputSpatialRank + 1)) { - return false; - } - - const int64_t kernelSpatialRank = - dimensionNumbers.getKernelSpatialDimensions().size(); - // The dimensions for filter should follow the order of - // spatial_dims..., input_feature_count, num_output_feature_count. - if (dimensionNumbers.getKernelInputFeatureDimension() != kernelSpatialRank || - dimensionNumbers.getKernelOutputFeatureDimension() != - (kernelSpatialRank + 1)) { - return false; - } - - const int64_t outputSpatialRank = - dimensionNumbers.getOutputSpatialDimensions().size(); - // The dimensions for output should follow the order of - // batch_count, spatial_dims.., output_feature_count. - if (dimensionNumbers.getOutputBatchDimension() != 0 || - dimensionNumbers.getOutputFeatureDimension() != (outputSpatialRank + 1)) { - return false; - } - - if (inputSpatialRank != outputSpatialRank || - inputSpatialRank != kernelSpatialRank) { - return false; - } - - const int64_t *inputSpatialDim = - dimensionNumbers.getInputSpatialDimensions().data(); - const int64_t *kernelSpatialDim = - dimensionNumbers.getKernelSpatialDimensions().data(); - const int64_t *outputSpatialDim = - dimensionNumbers.getOutputSpatialDimensions().data(); - // Check spatial dims are ordered correctly. - for (int64_t i = 0; i < inputSpatialRank; ++i) { - const int64_t dim = i + 1; - if ((*inputSpatialDim++) != dim || (*outputSpatialDim++) != dim || - (*kernelSpatialDim++) != i) { - return false; - } - } - - return true; -} - -/// Converts stablehlo.conv operation to linalg named op. This only covers -/// normal convolution cases. The op must have canonical dimension numbers. -/// Depthwise convolution and pointwise convolution are not handled in the -/// conversion. -struct NormalConvolutionOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!hasCanonicalDimensionNumbers(op.getDimensionNumbers())) { - return failure(); - } - if (op.getFeatureGroupCount() != 1u) - return failure(); - if (op.getBatchGroupCount() != 1u) - return failure(); - - Location loc = op.getLoc(); - Value input = adaptor.getLhs(); - Value filter = adaptor.getRhs(); - filter = applyConvolutionReversal(loc, rewriter, op, filter); - auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - - int64_t rank = resultType.getRank(); - - // Immediately emit an EmptyOp for output tensors with zero dimension. - if (llvm::is_contained(resultType.getShape(), 0)) { - rewriter.replaceOpWithNewOp(op, resultType.getShape(), - resultType.getElementType()); - return success(); - } - - // The output shape is N spatial_dims F. 
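// (Illustrative aside: the rank switch further below dispatched to linalg
// named ops; in the upstream version of this file these were
// linalg::MatmulOp for rank 2, and linalg::Conv1DNwcWcfOp,
// linalg::Conv2DNhwcHwcfOp, linalg::Conv3DNdhwcDhwcfOp for ranks 3, 4, 5,
// each accumulating into a zero-filled tensor.empty init.)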
- SmallVector dynSizes; - if (resultType.isDynamicDim(0)) { - dynSizes.push_back(rewriter.create(loc, input, 0)); - } - for (int64_t i = 1, e = rank - 1; i < e; ++i) { - if (resultType.isDynamicDim(i)) { - return rewriter.notifyMatchFailure( - op, "expected output spatial dims to be static shapes"); - } - } - if (resultType.isDynamicDim(rank - 1)) { - dynSizes.push_back(rewriter.create(loc, filter, rank - 1)); - } - Value emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), dynSizes); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - linalg::LinalgOp res; - Attribute strides; - if (auto s = op.getWindowStrides()) - strides = rewriter.getI64TensorAttr(*s); - Attribute dilations; - if (auto d = op.getRhsDilation()) - dilations = rewriter.getI64TensorAttr(*d); - - // Apply padding and input dilation. - llvm::SmallVector spatialDimMapping(rank - 2); - std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1); - input = applyConvolutionPadding(loc, input, op.getPaddingAttr(), - op.getLhsDilation(), spatialDimMapping, - rewriter); - - switch (rank) { - case 2: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - linalg::getPrunedAttributeList(op)); - break; - } - case 3: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - strides, dilations, linalg::getPrunedAttributeList(op)); - break; - } - case 4: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - strides, dilations, linalg::getPrunedAttributeList(op)); - break; - } - case 5: { - res = rewriter.create( - loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor}, - strides, dilations, linalg::getPrunedAttributeList(op)); - break; - } - default: { - return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op"); - } - } - rewriter.replaceOp(op, res.getOperation()->getResults()); - return success(); - } -}; - -/// Handles all possible inputs for the mlir::stablehlo::ConvolutionOp -struct ConvolutionOpGeneralConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - /// This lowering proceeds with the following steps: - /// 1. Handle padding and dilation of the input - /// 2. Handle padding and dilation of the window - /// 3. Handle reversal of the window - /// 4. If feature_group_count != 1: - /// - Reshape the input feature dimension, kernel output feature dimension, - /// and output feature dimension. - /// - Create the AffineExpr for the new dimension - /// - Conceptually, this splits the input feature and both output feature - /// dimensions and computes sets of convolutions with these partial views - /// of the values as if they were multiple convolutions combined in a - /// batch. - /// 5: If batch_group_count != 1: - /// - Reshape the input batch dimension, kernel output feature dimension, - /// and output feature dimension. - /// - Create the AffineExpr for the new dimension - /// - Conceptually, this splits the input batch and both output feature - /// dimensions and computes sets of convolutions with these partial views - /// of the values as if they were multiple convolutions combined in a - /// batch. - /// 6. For all dimensions not newly created by a reshape, create the - /// appropriate parallel and reduction dimensions to create a convolution. - /// 7. Create the linalg.generic that computes the multiply-add - /// 8. 
Reshape the output to the original shape if it was reshaped by the
-  ///    feature or group count attributes.
-  LogicalResult
-  matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    MLIRContext *ctx = op.getContext();
-
-    auto resultType = dyn_cast_or_null<ShapedType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
-    if (!resultType) {
-      return rewriter.notifyMatchFailure(op, "type conversion failed");
-    }
-
-    auto reshapedResultShape = resultType.getShape().vec();
-    if (!resultType.hasStaticShape())
-      return failure();
-
-    // Immediately emit an EmptyOp for output tensors with zero dimension.
-    if (llvm::is_contained(reshapedResultShape, 0)) {
-      rewriter.replaceOpWithNewOp<tensor::EmptyOp>(op, reshapedResultShape,
-                                                   resultType.getElementType());
-      return success();
-    }
-
-    mlir::stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
-        op.getDimensionNumbers();
-    int64_t inputBatchDimension = dimensionNumbers.getInputBatchDimension();
-    int64_t inputFeatureDimension = dimensionNumbers.getInputFeatureDimension();
-    ArrayRef<int64_t> inputSpatialDimensions =
-        dimensionNumbers.getInputSpatialDimensions();
-
-    int64_t kernelInputFeatureDimension =
-        dimensionNumbers.getKernelInputFeatureDimension();
-    int64_t kernelOutputFeatureDimension =
-        dimensionNumbers.getKernelOutputFeatureDimension();
-    ArrayRef<int64_t> kernelSpatialDimensions =
-        dimensionNumbers.getKernelSpatialDimensions();
-
-    int64_t outputBatchDimension = dimensionNumbers.getOutputBatchDimension();
-    int64_t outputFeatureDimension =
-        dimensionNumbers.getOutputFeatureDimension();
-    ArrayRef<int64_t> outputSpatialDimensions =
-        dimensionNumbers.getOutputSpatialDimensions();
-
-    size_t featureGroupCount = op.getFeatureGroupCount();
-    size_t batchGroupCount = op.getBatchGroupCount();
-
-    if (op.getFeatureGroupCount() != 1 && op.getBatchGroupCount() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "only one of feature and batch group counts can be non-one");
-    }
-
-    // Decompose the convolution into an initial padding.
-    Value modifiedLhs = applyConvolutionPadding(
-        op.getLoc(), adaptor.getLhs(), adaptor.getPaddingAttr(),
-        adaptor.getLhsDilation(),
-        op.getDimensionNumbers().getInputSpatialDimensions(), rewriter);
-    Value modifiedRhs = applyConvolutionPadding(
-        op.getLoc(), adaptor.getRhs(), nullptr, adaptor.getRhsDilation(),
-        op.getDimensionNumbers().getKernelSpatialDimensions(), rewriter);
-    modifiedRhs = applyConvolutionReversal(loc, rewriter, op, modifiedRhs);
-
-    // Non-one values for feature or batch group counts will result in reshaped
-    // inputs and outputs. These mappings are used to keep track of the new
-    // index after reshaping has possibly inserted new dimensions.
-    auto paddedLhsType = cast<ShapedType>(modifiedLhs.getType());
-    auto paddedRhsType = cast<ShapedType>(modifiedRhs.getType());
-    SmallVector<int64_t> lhsIndexMapping(paddedLhsType.getRank());
-    std::iota(lhsIndexMapping.begin(), lhsIndexMapping.end(), 0);
-    SmallVector<int64_t> rhsIndexMapping(paddedRhsType.getRank());
-    std::iota(rhsIndexMapping.begin(), rhsIndexMapping.end(), 0);
-    SmallVector<int64_t> resultIndexMapping(resultType.getRank());
-    std::iota(resultIndexMapping.begin(), resultIndexMapping.end(), 0);
-    auto updateDimMappingFromOffset =
-        [](llvm::SmallVectorImpl<int64_t> &mapping, int64_t offset) {
-          for (auto &mappingElt : llvm::drop_begin(mapping, offset)) {
-            mappingElt += 1;
-          }
-        };
-
-    // The rest of this code prepares the inputs and a single linalg::GenericOp
-    // to execute the convolution.
The final linalg::GenericOp will be iterated - // through based on the following eventual maps. - SmallVector srcExprs(paddedLhsType.getRank()); - SmallVector windowExprs(paddedRhsType.getRank()); - SmallVector dstExprs(reshapedResultShape.size()); - int64_t nextDim = 0; - int64_t rank = resultType.getRank(); - - auto reshapeShapeVector = [](llvm::ArrayRef oldShape, - llvm::SmallVectorImpl &newShape, - int64_t reshapedDim, int64_t factor) { - newShape.reserve(oldShape.size() + 1); - for (int64_t i : llvm::seq(0, oldShape.size())) { - if (i == reshapedDim) { - newShape.push_back(factor); - newShape.push_back(oldShape[reshapedDim] / factor); - } else { - newShape.push_back(oldShape[i]); - } - } - }; - - // If batch or feature count groupings exist, represent this through - // reshaping the input to have an additional dimension that these groupings - // exist along, and reduce in that dimension - SmallVector iterationLoops; - if (featureGroupCount != 1) { - AffineExpr parallelDim = mlir::getAffineDimExpr(nextDim++, ctx); - iterationLoops.push_back(utils::IteratorType::parallel); - // Reshape LHS - { - srcExprs.insert(srcExprs.begin() + inputFeatureDimension, parallelDim); - auto prevDimsRef = paddedLhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, inputFeatureDimension, - featureGroupCount); - updateDimMappingFromOffset(lhsIndexMapping, inputFeatureDimension); - modifiedLhs = rewriter.create( - loc, - RankedTensorType::get(newShape, paddedLhsType.getElementType()), - modifiedLhs); - } - - // Reshape RHS - { - windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension, - parallelDim); - auto prevDimsRef = paddedRhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension, - featureGroupCount); - updateDimMappingFromOffset(rhsIndexMapping, - kernelOutputFeatureDimension); - modifiedRhs = rewriter.create( - loc, - RankedTensorType::get(newShape, paddedRhsType.getElementType()), - modifiedRhs); - } - // Prepare reshaped output shape - { - dstExprs.insert(dstExprs.begin() + outputFeatureDimension, parallelDim); - updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension); - reshapedResultShape.insert(reshapedResultShape.begin() + - outputFeatureDimension, - featureGroupCount); - reshapedResultShape[outputFeatureDimension + 1] /= featureGroupCount; - } - } - - if (batchGroupCount != 1) { - iterationLoops.push_back(utils::IteratorType::parallel); - AffineExpr parallelDim = mlir::getAffineDimExpr(nextDim++, ctx); - // Reshape LHS - { - srcExprs.insert(srcExprs.begin() + inputBatchDimension, parallelDim); - ArrayRef prevDimsRef = paddedLhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, inputBatchDimension, - batchGroupCount); - updateDimMappingFromOffset(lhsIndexMapping, inputBatchDimension); - modifiedLhs = rewriter.create( - op.getLoc(), - RankedTensorType::get(newShape, paddedLhsType.getElementType()), - modifiedLhs); - } - - // Reshape RHS - { - windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension, - parallelDim); - ArrayRef prevDimsRef = paddedRhsType.getShape(); - llvm::SmallVector newShape; - reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension, - batchGroupCount); - updateDimMappingFromOffset(rhsIndexMapping, - kernelOutputFeatureDimension); - modifiedRhs = rewriter.create( - op.getLoc(), - RankedTensorType::get(newShape, paddedRhsType.getElementType()), - modifiedRhs); - } - // Prepare 
reshaped output shape - { - int64_t outputFeatureDim = resultIndexMapping[outputFeatureDimension]; - dstExprs.insert(dstExprs.begin() + outputFeatureDim, parallelDim); - updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension); - reshapedResultShape.insert( - reshapedResultShape.begin() + outputFeatureDim, batchGroupCount); - reshapedResultShape[outputFeatureDim + 1] /= batchGroupCount; - } - } - - // Handle input feature dimension - { - iterationLoops.push_back(utils::IteratorType::reduction); - AffineExpr inputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx); - srcExprs[lhsIndexMapping[inputFeatureDimension]] = inputFeatureDim; - windowExprs[rhsIndexMapping[kernelInputFeatureDimension]] = - inputFeatureDim; - } - - // Handle output feature dimension - { - iterationLoops.push_back(utils::IteratorType::parallel); - AffineExpr outputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx); - dstExprs[resultIndexMapping[outputFeatureDimension]] = outputFeatureDim; - windowExprs[rhsIndexMapping[kernelOutputFeatureDimension]] = - outputFeatureDim; - } - - // Handle spatial Dimensions - int64_t numSpatialDims = rank - 2; - for (int64_t i = 0; i < numSpatialDims; ++i) { - iterationLoops.push_back(utils::IteratorType::parallel); - iterationLoops.push_back(utils::IteratorType::reduction); - AffineExpr dim0 = mlir::getAffineDimExpr(nextDim++, ctx); - AffineExpr dim1 = mlir::getAffineDimExpr(nextDim++, ctx); - - AffineExpr stride = dim0; - if (op.getWindowStrides().has_value()) - stride = stride * op.getWindowStrides().value()[i]; - AffineExpr srcExpr = stride + dim1; - - srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr; - dstExprs[resultIndexMapping[outputSpatialDimensions[i]]] = dim0; - windowExprs[rhsIndexMapping[kernelSpatialDimensions[i]]] = dim1; - } - - // Handle batch dimension - { - iterationLoops.push_back(utils::IteratorType::parallel); - AffineExpr batchDim = mlir::getAffineDimExpr(nextDim++, ctx); - - srcExprs[lhsIndexMapping[inputBatchDimension]] = batchDim; - dstExprs[resultIndexMapping[outputBatchDimension]] = batchDim; - } - - // Finally, create the computation - auto inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); - - Value emptyTensor = rewriter.create( - loc, reshapedResultShape, resultType.getElementType()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - - Value convolved = - rewriter - .create( - loc, - /*resultTensors=*/ - llvm::ArrayRef(zeroTensor.getType()), - /*inputs=*/ - llvm::ArrayRef({modifiedLhs, modifiedRhs}), - /*outputs=*/llvm::ArrayRef(zeroTensor), inferredMaps, - iterationLoops, - /*bodyBuild=*/ - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange) { - ImplicitLocOpBuilder builder(nestedLoc, nestedBuilder); - linalg::Conv2DOp::regionBuilder( - builder, *builder.getInsertionBlock(), {}); - }, - linalg::getPrunedAttributeList(op)) - .getResult(0); - rewriter.replaceOpWithNewOp(op, resultType, - convolved); - - return success(); - } -}; - -/// Converts stablehlo.convolution operation to -/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or -/// depthwise_conv_2d_input_nhwc_filter_hwc op. -struct DepthwiseConvolutionOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getBatchGroupCount() != 1) - return failure(); - // Fall into the normal convolution cases. 
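// (Illustrative aside: "depthwise" here means feature_group_count equals the
// input channel count, e.g. a tensor<1x10x10x8xf32> lhs convolved with
// feature_group_count = 8, so that each input channel is filtered
// independently. The guards below reject anything else.)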
- if (op.getFeatureGroupCount() == 1) - return failure(); - - const mlir::stablehlo::ConvDimensionNumbersAttr &dimensionNumbers = - op.getDimensionNumbers(); - const int64_t spatialRank = - dimensionNumbers.getInputSpatialDimensions().size(); - if (spatialRank == 0 || spatialRank > 3) { - return rewriter.notifyMatchFailure(op, "only support up to 3D for now"); - } - - // Make sure that this is depthwise convolution. - int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension(); - int64_t inputFeatureCount = - cast(op.getLhs().getType()).getDimSize(inputFeatureDim); - if (static_cast(op.getFeatureGroupCount()) != inputFeatureCount) { - return rewriter.notifyMatchFailure(op, "not depth-wise convolution"); - } - - // Make sure that this convolution has a canonical form. - if (!hasCanonicalDimensionNumbers(dimensionNumbers)) { - return rewriter.notifyMatchFailure(op, "does not have canonical form"); - } - - Attribute windowStrides; - if (op.getWindowStrides()) { - windowStrides = rewriter.getI64TensorAttr(op.getWindowStrides().value()); - } else { - windowStrides = SplatElementsAttr::get( - VectorType::get({spatialRank}, rewriter.getI64Type()), - rewriter.getI64IntegerAttr(1)); - } - - Attribute rhsDilation; - if (op.getRhsDilation()) { - rhsDilation = rewriter.getI64TensorAttr(op.getRhsDilation().value()); - } else { - rhsDilation = SplatElementsAttr::get( - VectorType::get({spatialRank}, rewriter.getI64Type()), - rewriter.getI64IntegerAttr(1)); - } - - Location loc = op.getLoc(); - Value input = adaptor.getLhs(); - Value filter = adaptor.getRhs(); - auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); - if (!resultType) { - return rewriter.notifyMatchFailure(op, "type conversion failed"); - } - if (!resultType.hasStaticShape()) { - return rewriter.notifyMatchFailure(op, - "expected output has static shapes"); - } - - // Immediately emit an EmptyOp for output tensors with zero dimension. - if (llvm::is_contained(resultType.getShape(), 0)) { - rewriter.replaceOpWithNewOp(op, resultType.getShape(), - resultType.getElementType()); - return success(); - } - - // Apply padding and input dilation. - llvm::SmallVector spatialDimMapping(spatialRank); - std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1); - input = applyConvolutionPadding(loc, input, op.getPaddingAttr(), - op.getLhsDilation(), spatialDimMapping, - rewriter); - - auto filterDims = - llvm::to_vector(cast(op.getRhs().getType()).getShape()); - - auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) { - SmallVector reassociations; - int64_t rank = cast(v.getType()).getRank(); - for (int64_t i = 0; i < rank - 1; ++i) - reassociations.emplace_back(1, i); - reassociations.back().push_back(rank - 1); - return reassociations; - }; - - int64_t kernelInputFeatureDimension = - dimensionNumbers.getKernelInputFeatureDimension(); - int64_t kernelOutputFeatureDimension = - dimensionNumbers.getKernelOutputFeatureDimension(); - if (filterDims[kernelInputFeatureDimension] * - filterDims[kernelOutputFeatureDimension] != - static_cast(op.getFeatureGroupCount())) { - // For cases where channel multiplier != 1 - - // Reshaping filter shape - // [filter_height, filter_width, 1, kernel-output-feature]. 
-      // to
-      // [filter_height, filter_width, feature_group_count,
-      //  kernel-output-feature/feature_group_count]
-      SmallVector<int64_t> reshapedFilterDims;
-      reshapedFilterDims.assign(filterDims.begin(), filterDims.end());
-      Value reshapedFilter = filter;
-      if (filterDims[kernelInputFeatureDimension] == 1) {
-        reshapedFilterDims[kernelInputFeatureDimension] =
-            op.getFeatureGroupCount();
-        reshapedFilterDims[kernelOutputFeatureDimension] /=
-            op.getFeatureGroupCount();
-        auto reshapedFilterType = RankedTensorType::get(
-            reshapedFilterDims,
-            cast<ShapedType>(op.getRhs().getType()).getElementType());
-
-        reshapedFilter = rewriter.create<mlir::stablehlo::ReshapeOp>(
-            loc, reshapedFilterType, filter);
-      }
-
-      ArrayRef<int64_t> outputDims = resultType.getShape();
-      int64_t channelMultiplier = reshapedFilterDims.back();
-      SmallVector<int64_t> reshapedOutputDims;
-      reshapedOutputDims.assign(outputDims.begin(), outputDims.end());
-      reshapedOutputDims.push_back(channelMultiplier);
-      reshapedOutputDims[reshapedOutputDims.size() - 2] /= channelMultiplier;
-
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, reshapedOutputDims, resultType.getElementType());
-      Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
-
-      auto reshapedOutputType = RankedTensorType::get(
-          reshapedOutputDims, resultType.getElementType());
-      Value conv;
-      switch (spatialRank) {
-      case 1: {
-        conv =
-            rewriter
-                .create<linalg::DepthwiseConv1DNwcWcmOp>(
-                    loc, reshapedOutputType, ValueRange{input, reshapedFilter},
-                    ValueRange{zeroTensor}, windowStrides, rhsDilation,
-                    linalg::getPrunedAttributeList(op))
-                .getResult(0);
-        break;
-      }
-      case 2: {
-        conv =
-            rewriter
-                .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-                    loc, reshapedOutputType, ValueRange{input, reshapedFilter},
-                    ValueRange{zeroTensor}, windowStrides, rhsDilation,
-                    linalg::getPrunedAttributeList(op))
-                .getResult(0);
-        break;
-      }
-      case 3: {
-        conv =
-            rewriter
-                .create<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
-                    loc, reshapedOutputType, ValueRange{input, reshapedFilter},
-                    ValueRange{zeroTensor}, windowStrides, rhsDilation,
-                    linalg::getPrunedAttributeList(op))
-                .getResult(0);
-        break;
-      }
-      default:
-        llvm_unreachable("Unhandled case");
-      }
-
-      // Create a Linalg reshape op that converts the output from 5 dimensions
-      // into 4 dimensions (by collapsing the last two dimensions). This is
-      // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
-      // 5 dimensions for the output.
-      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-          op, resultType, conv,
-          getReassociationIndicesToCollapseLastTwoDims(conv));
-    } else {
-      // For cases where channel multiplier == 1
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, resultType.getShape(), resultType.getElementType());
-      Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
-
-      // Create a Linalg reshape op that converts the filter from 4 dimensions
-      // into 3 dimensions (by dropping the unit dimension). This is needed
-      // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
-      // dimensions for the filter.
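// (Illustrative aside: for a rank-4 filter the collapse helper yields the
// reassociation {{0}, {1}, {2, 3}}, so an [H, W, C, 1] filter collapses to
// the [H, W, C] shape the *_hwc depthwise ops expect; the analogous rank-5
// output [N, H, W, C, M] collapses to [N, H, W, C * M].)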
- - filterDims[filterDims.size() - 2] = - static_cast(op.getFeatureGroupCount()); - filterDims.pop_back(); - - RankedTensorType filterShape = - RankedTensorType::get(filterDims, op.getType().getElementType()); - - Value reshapedFilter = rewriter.create( - loc, filterShape, filter, - getReassociationIndicesToCollapseLastTwoDims(filter)); - - switch (spatialRank) { - case 1: - rewriter.replaceOpWithNewOp( - op, resultType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)); - break; - case 2: - rewriter.replaceOpWithNewOp( - op, resultType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)); - break; - case 3: - rewriter.replaceOpWithNewOp( - op, resultType, ValueRange{input, reshapedFilter}, - ValueRange{zeroTensor}, windowStrides, rhsDilation, - linalg::getPrunedAttributeList(op)); - break; - } - } - - return success(); - } -}; - -} // namespace - -namespace detail { -void populateStableHloConvolutionToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - // Ensure specialized patterns are higher priority than their generic - // versions. - patterns - ->add( - typeConverter, context, PatternBenefit(2)); - - patterns->add(typeConverter, context); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp deleted file mode 100644 index 0db4fafec6bc..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgDotProd.cpp +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO dot product ops to Linalg dialect. -// These patterns are separated out to their own file to save on the compilation -// times, given that we instantiate a large number of class templates here. 
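To make the rank-based classification these deleted patterns performed easy to check in isolation, here is a minimal plain-C++ sketch mirroring the getDotOperationType helper below (no MLIR dependencies; DotKind and classifyDot are illustrative names, and -1 stands in for a dynamic dimension):

#include <cstdint>
#include <vector>

// Sketch: classify a stablehlo.dot by operand ranks and shape agreement.
enum class DotKind {
  VectorDot, MatrixVector, VectorMatrix, MatrixMatrix, Unsupported
};

DotKind classifyDot(const std::vector<int64_t> &lhs,
                    const std::vector<int64_t> &rhs) {
  // A dynamic extent (-1) matches anything, as in the MLIR version.
  auto match = [](int64_t a, int64_t b) { return a == -1 || b == -1 || a == b; };
  if (lhs.size() == 1 && rhs.size() == 1 && match(lhs[0], rhs[0]))
    return DotKind::VectorDot;     // (k) . (k) -> scalar
  if (lhs.size() == 2 && rhs.size() == 1 && match(lhs[1], rhs[0]))
    return DotKind::MatrixVector;  // (m, k) . (k) -> (m)
  if (lhs.size() == 1 && rhs.size() == 2 && match(lhs[0], rhs[0]))
    return DotKind::VectorMatrix;  // (k) . (k, n) -> (n)
  if (lhs.size() == 2 && rhs.size() == 2 && match(lhs[1], rhs[0]))
    return DotKind::MatrixMatrix;  // (m, k) . (k, n) -> (m, n)
  return DotKind::Unsupported;
}

Each kind then selects a different linalg named op (vecmat, matvec, matmul, dot), which is why the classification happens up front.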
- -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -enum class DotOperationType { - kVectorDot = 0, - kMatrixVector, - kVectorMatrix, - kMatrixMatrix, - kUnsupported -}; - -DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) { - ArrayRef lhsShape = - cast(dotOp.getLhs().getType()).getShape(); - ArrayRef rhsShape = - cast(dotOp.getRhs().getType()).getShape(); - auto shapeMatches = [](int64_t a, int64_t b) { - return ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b; - }; - if (lhsShape.size() == 1 && rhsShape.size() == 1 && - shapeMatches(lhsShape[0], rhsShape[0])) { - return DotOperationType::kVectorDot; - } - if (lhsShape.size() == 2 && rhsShape.size() == 1 && - shapeMatches(lhsShape[1], rhsShape[0])) { - return DotOperationType::kMatrixVector; - } - if (lhsShape.size() == 1 && rhsShape.size() == 2 && - shapeMatches(lhsShape[0], rhsShape[0])) { - return DotOperationType::kVectorMatrix; - } - if (lhsShape.size() == 2 && rhsShape.size() == 2 && - shapeMatches(lhsShape[1], rhsShape[0])) { - return DotOperationType::kMatrixMatrix; - } - return DotOperationType::kUnsupported; -} - -SmallVector getDotOpEmptyTensorDynSizes(OpBuilder &b, Location loc, - Value lhs, Value rhs, - DotOperationType type) { - SmallVector dynShape; - switch (type) { - case DotOperationType::kMatrixMatrix: { - if (llvm::cast(lhs.getType()).isDynamicDim(0)) - dynShape.push_back(b.create(loc, lhs, 0)); - if (llvm::cast(rhs.getType()).isDynamicDim(1)) - dynShape.push_back(b.create(loc, rhs, 1)); - break; - } - case DotOperationType::kMatrixVector: { - if (llvm::cast(lhs.getType()).isDynamicDim(0)) - dynShape.push_back(b.create(loc, lhs, 0)); - break; - } - case DotOperationType::kVectorMatrix: { - if (llvm::cast(rhs.getType()).isDynamicDim(1)) - dynShape.push_back(b.create(loc, rhs, 1)); - break; - } - case DotOperationType::kVectorDot: - case DotOperationType::kUnsupported: - break; - } - return dynShape; -} - -template -struct DotOpConversion final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = mlir::stablehlo::DotOp::Adaptor; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) { - return failure(); - } - if (getDotOperationType(op) != op_type) - return failure(); - - Location loc = op.getLoc(); - // Convert unsigned to signed. This works because signed and unsigned - // integer matmul is the same operation in two's complement. - auto outputType = - cast(getTypeConverter()->convertType(op.getType())); - SmallVector dynShape = getDotOpEmptyTensorDynSizes( - rewriter, loc, adaptor.getLhs(), adaptor.getRhs(), op_type); - Value emptyTensor = - !sparse_tensor::getSparseTensorEncoding(outputType) - ? 
getEmptyTensor(rewriter, loc, outputType, dynShape) - : getEmptySparseTensor(rewriter, loc, outputType, dynShape); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - rewriter.replaceOpWithNewOp( - op, TypeRange{outputType}, - ValueRange{adaptor.getLhs(), adaptor.getRhs()}, ValueRange{zeroTensor}, - linalg::getPrunedAttributeList(op)); - return success(); - } -}; - -struct DotGeneralBatchMatMulOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) { - return failure(); - } - if (llvm::cast(op.getType()).getRank() != 3) { - return rewriter.notifyMatchFailure(op, "expected a batch matmul"); - } - - mlir::stablehlo::DotDimensionNumbersAttr dimNumbers = - op.getDotDimensionNumbers(); - ArrayRef lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); - ArrayRef rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); - ArrayRef lhsContractingDims = - dimNumbers.getLhsContractingDimensions(); - ArrayRef rhsContractingDims = - dimNumbers.getRhsContractingDimensions(); - if (lhsBatchingDims.size() != 1 || lhsBatchingDims[0] != 0) { - return rewriter.notifyMatchFailure( - op, "expected lhs batching dimensions exactly {0}"); - } - if (rhsBatchingDims.size() != 1 || rhsBatchingDims[0] != 0) { - return rewriter.notifyMatchFailure( - op, "expected rhs batching dimensions exactly {0}"); - } - if (lhsContractingDims.size() != 1 || lhsContractingDims[0] != 2) { - return rewriter.notifyMatchFailure( - op, "expected lhs contracting dimensions exactly {2}"); - } - if (rhsContractingDims.size() != 1 || rhsContractingDims[0] != 1) { - return rewriter.notifyMatchFailure( - op, "expected rhs contracting dimensions exactly {1}"); - } - - Location loc = op.getLoc(); - // Convert unsigned to signed. This works because signed and unsigned - // integer matmul is the same operation in two's complement. 
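// (Illustrative aside: with i8, 0xFF * 0x02 = 0x01FE, which truncates to
// 0xFE; read unsigned that is 255 * 2 = 510 mod 256 = 254, and read signed
// it is (-1) * 2 = -2. The bit patterns are identical, so treating unsigned
// inputs as signless/signed here is safe.)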
- auto outputType = - cast(typeConverter->convertType(op.getType())); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - Operation *linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/TypeRange{outputType}, - /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()}, - /*outputBuffers=*/ValueRange{zeroTensor}, - linalg::getPrunedAttributeList(op)); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; - -struct DotGeneralOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (failed(verifyHloOpBufferOrTensorSemantics(op))) { - return failure(); - } - - // Get various dimension iterator information - mlir::stablehlo::DotDimensionNumbersAttr dimNumbers = - op.getDotDimensionNumbers(); - ArrayRef lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); - ArrayRef rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); - ArrayRef lhsContractingDims = - dimNumbers.getLhsContractingDimensions(); - ArrayRef rhsContractingDims = - dimNumbers.getRhsContractingDimensions(); - - // Get shape information and initialize output - assert(lhsContractingDims.size() == rhsContractingDims.size() && - "number of contracting dims must be equal"); - size_t numContracting = lhsContractingDims.size(); - // Convert unsigned to signed. This works because signed and unsigned - // integer matmul is the same operation in two's complement. - auto outputType = - cast(typeConverter->convertType(op.getType())); - size_t targetRank = outputType.getRank(); - size_t totalLoopCount = numContracting + targetRank; - - int64_t lhsRank = - llvm::cast(adaptor.getLhs().getType()).getRank(); - size_t lhsExtraDims = - lhsRank - lhsBatchingDims.size() - lhsContractingDims.size(); - int64_t rhsRank = - llvm::cast(adaptor.getRhs().getType()).getRank(); - - Location loc = op.getLoc(); - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); - Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); - SmallVector indexingMaps; - - auto getMap = [&](int64_t rank, ArrayRef batchingDims, - ArrayRef contractingDims, size_t extraDims) { - llvm::SmallVector indices(rank); - for (const auto &i : llvm::enumerate(batchingDims)) { - indices[i.value()] = rewriter.getAffineDimExpr(i.index()); - } - for (const auto &i : llvm::enumerate(contractingDims)) { - indices[i.value()] = rewriter.getAffineDimExpr(i.index() + targetRank); - } - for (int i = 0; i < rank; ++i) { - if (!indices[i]) { - indices[i] = rewriter.getAffineDimExpr(extraDims++); - } - } - indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount, - /*symbolCount=*/0, indices, - op->getContext())); - }; - getMap(lhsRank, lhsBatchingDims, lhsContractingDims, - lhsBatchingDims.size()); - getMap(rhsRank, rhsBatchingDims, rhsContractingDims, - rhsBatchingDims.size() + lhsExtraDims); - - { - SmallVector dimExprs; - dimExprs.reserve(targetRank); - for (unsigned i = 0; i < targetRank; ++i) - dimExprs.push_back(rewriter.getAffineDimExpr(i)); - indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount, - /*symbolCount=*/0, dimExprs, - op.getContext())); - } - - Operation *linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/TypeRange{outputType}, - /*inputs=*/ValueRange{adaptor.getLhs(), 
adaptor.getRhs()}, - /*outputBuffers=*/ValueRange{zeroTensor}, indexingMaps, - getParallelAndReductionIterators( - /*nLoops=*/totalLoopCount, - /*nReduction=*/numContracting), - [](OpBuilder &b, Location loc, ValueRange) { - ImplicitLocOpBuilder builder(loc, b); - linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {}); - }, - linalg::getPrunedAttributeList(op)); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; - -} // namespace - -namespace detail { -void populateStableHloDotProdToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns) { - // Ensure specialized patterns are higher priority than their generic - // versions. - patterns - ->add, - DotOpConversion, - DotOpConversion, - DotOpConversion, - DotGeneralBatchMatMulOpConversion>(typeConverter, context, - PatternBenefit(2)); - patterns->add(typeConverter, context, - PatternBenefit(1)); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp deleted file mode 100644 index 6f2985abd50b..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgPointwise.cpp +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO pointwise ops to Linalg dialect. -// These patterns are separated out to their own file to save on the compilation -// times, given that we instantiate a large number of class templates here. - -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -int64_t getRank(Value v) { return cast(v.getType()).getRank(); } - -int64_t getMaxRank(ValueRange operands) { - int64_t maxRank = 0; - for (Value operand : operands) { - maxRank = std::max(maxRank, getRank(operand)); - } - return maxRank; -} - -bool isScalar(Value v) { return getRank(v) == 0; } - -/// Inserts block arguments in places where scalar inputs have a nullptr. -SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, - ValueRange blockArgs) { - SmallVector result; - auto argsIter = blockArgs.begin(); - for (Value scalarInput : scalarInputs) { - if (scalarInput) { - result.push_back(scalarInput); - } else { - result.push_back(*argsIter); - ++argsIter; - } - } - return result; -} - -struct PointwiseConversionInfo { - int64_t maxOperandRank = 0; - ShapedType resultType; -}; - -/// Checks the preconditions for conversion of pointwise HLO ops to linalg. -/// Returns the max operand rank and the result type on success. 
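// (Illustrative aside: "scalar or same rank" admits mixed operands such as
// a stablehlo.select whose predicate is a rank-0 tensor<i1> applied to two
// tensor<4x8xf32> operands; maxRank is then 2 and the scalar predicate is
// implicitly broadcast.)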
-FailureOr<PointwiseConversionInfo>
-checkOperandsAndResults(Operation *op, ValueRange operands,
-                        const TypeConverter &typeConverter,
-                        ConversionPatternRewriter &rewriter) {
-  int64_t maxRank = getMaxRank(operands);
-
-  // Apply only if all operands are scalar or have the same rank. Some ops,
-  // like `stablehlo.select`, support implicit broadcasting of scalars.
-  if (!llvm::all_of(operands, [&](Value v) {
-        int64_t r = getRank(v);
-        return r == 0 || r == maxRank;
-      })) {
-    return rewriter.notifyMatchFailure(
-        op, "Operands must be of same rank or scalar.");
-  }
-
-  // Find result type, if on tensors.
-  auto resultTy = dyn_cast_or_null<ShapedType>(
-      typeConverter.convertType(op->getResultTypes().front()));
-
-  // Check result type compatibility.
-  if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
-      !(resultTy.getElementType().isSignlessIntOrFloat() ||
-        isa<ComplexType>(resultTy.getElementType()))) {
-    return rewriter.notifyMatchFailure(
-        op, "mismatched operand/result types or iterator count");
-  }
-
-  // All-scalar pointwise ops inside of linalg ops are processed by
-  // ScalarHloToArithmeticPattern.
-  if (maxRank == 0 && isInBodyOfLinalgOps(op))
-    return failure();
-
-  return PointwiseConversionInfo{maxRank, resultTy};
-}
-
-/// Converts an HLO operation to a linalg.map op that contains the
-/// corresponding scalar operations.
-template <typename OpTy>
-struct PointwiseToLinalgMapConverter final : OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-  using OpAdaptor = typename OpTy::Adaptor;
-
-  LogicalResult
-  matchAndRewrite(OpTy op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto conversionInfo = checkOperandsAndResults(
-        op, adaptor.getOperands(), *this->typeConverter, rewriter);
-    if (failed(conversionInfo)) {
-      return failure();
-    }
-
-    int64_t maxRank = conversionInfo->maxOperandRank;
-    ShapedType resultTy = conversionInfo->resultType;
-    Location loc = op.getLoc();
-
-    // Find input/output values and types.
-    Value emptyTensor =
-        getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());
-
-    // Mapped inputs are cast to the same shape as the init tensor.
-    // Values from scalar inputs are extracted and used directly in the block.
-    SmallVector<Value> mappedInputs;
-    SmallVector<Value> scalarInputs;
-    for (Value input : adaptor.getOperands()) {
-      if (getRank(input) == maxRank) {
-        mappedInputs.push_back(coerceTensorShape(
-            rewriter, loc, cast<TypedValue<ShapedType>>(input),
-            cast<ShapedType>(emptyTensor.getType())));
-        scalarInputs.push_back(nullptr);
-      } else {
-        scalarInputs.push_back(rewriter.create<tensor::ExtractOp>(loc, input));
-      }
-    }
-
-    auto mapOp = rewriter.create<linalg::MapOp>(
-        loc, mappedInputs, emptyTensor,
-        [&](OpBuilder &b, Location loc, ValueRange args) {
-          Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
-              op, getElementTypeOrSelf(emptyTensor),
-              interleaveScalarAndBlockArgs(scalarInputs, args), &b);
-
-          b.create<linalg::YieldOp>(loc, innerResult);
-        },
-        linalg::getPrunedAttributeList(op));
-
-    rewriter.replaceOp(op, mapOp->getResults());
-    return success();
-  }
-};
-
-/// Converts an HLO operation to a linalg.generic op that contains the
-/// corresponding scalar operations.
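// (Illustrative aside: for, say, stablehlo.add on tensor<4xf32> the pattern
// below builds roughly
//   %init = tensor.empty() : tensor<4xf32>
//   %out = linalg.generic {indexing_maps = [#id, #id, #id],
//                          iterator_types = ["parallel"]}
//          ins(%lhs, %rhs : ...) outs(%init : ...) {
//   ^bb0(%l: f32, %r: f32, %o: f32):
//     %s = arith.addf %l, %r : f32
//     linalg.yield %s : f32
//   }
// with rank-0 operands indexed through an empty (no-result) affine map.)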
-template -struct PointwiseToLinalgConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpTy::Adaptor; - - LogicalResult - matchAndRewrite(OpTy op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto conversionInfo = checkOperandsAndResults( - op, adaptor.getOperands(), *this->typeConverter, rewriter); - if (failed(conversionInfo)) { - return failure(); - } - - int64_t maxRank = conversionInfo->maxOperandRank; - ShapedType resultTy = conversionInfo->resultType; - Location loc = op.getLoc(); - - // Find input/output values and types. - ValueRange inputs = adaptor.getOperands(); - Value output = - getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - - // Create indexing maps. - AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext()); - AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank); - SmallVector maps; - for (Value v : inputs) - maps.push_back(isScalar(v) ? scalarMap : idMap); - maps.push_back(idMap); - - // Build `linalg.generic` op. - bool failed = false; - auto linalgOp = rewriter.create( - loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps, - getNParallelLoopsAttrs(maxRank), - [&](OpBuilder &nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - Type innerResultTy = getElementTypeOrSelf(output); - auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); - Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter); - Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, innerResultTy, argvec, &rewriter); - if (!innerResult) { - failed = true; - } else { - innerResult = postSparsify(op, semiring, innerResult, &rewriter); - nestedBuilder.create(loc, innerResult); - } - }, - linalg::getPrunedAttributeList(op)); - if (failed) - return failure(); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; -} // namespace - -namespace detail { -void populatePointwiseStableHloToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, bool enablePrimitiveOps) { - if (enablePrimitiveOps) { - patterns->add< - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - 
PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter>(typeConverter, - context); - return; - } - - patterns - ->add, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter>(typeConverter, - context); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp deleted file mode 100644 index e1e73ed75bbc..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgRandom.cpp +++ /dev/null @@ -1,917 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements logic for lowering StableHLO random number generation to Linalg -// dialect. 
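As a standalone reference for the generator these patterns materialize as linalg ops, here is a minimal plain-C++ Threefry-2x32 sketch, mirroring the structure of runThreeFry2xi32 further below (5 blocks of 4 rounds; the function and constant names are illustrative, not a library API):

#include <array>
#include <cstdint>

// Minimal standalone Threefry-2x32 sketch (5 x 4 rounds).
std::array<uint32_t, 2> threefry2x32(uint32_t key0, uint32_t key1,
                                     std::array<uint32_t, 2> counter) {
  constexpr uint32_t kRotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
  constexpr uint32_t kMagic = 0x1bd11bda; // Threefry key-schedule constant.
  auto rotl = [](uint32_t v, uint32_t r) { return (v << r) | (v >> (32 - r)); };

  const uint32_t ks[3] = {key0, key1, kMagic ^ key0 ^ key1};
  std::array<uint32_t, 2> x = {counter[0] + ks[0], counter[1] + ks[1]};

  for (int i = 0; i < 5; ++i) {
    const int rot = (4 * i) % 8;
    for (int j = 0; j < 4; ++j) {
      // One mix round: add, rotate, xor.
      x[0] += x[1];
      x[1] = rotl(x[1], kRotations[rot + j]);
      x[1] ^= x[0];
    }
    // Key injection after every block of four rounds.
    x[0] += ks[(i + 1) % 3];
    x[1] += ks[(i + 2) % 3] + static_cast<uint32_t>(i + 1);
  }
  return x;
}

A given (key0, key1, counter) triple deterministically yields two 32-bit words, which is why the lowering below generates values in pairs and needs the halving/concatenation shape bookkeeping that follows.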
- -#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" -#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir::iree_compiler::stablehlo { -namespace { -class ArithOpBuilder { -public: - ArithOpBuilder(OpBuilder b, Location l, Value v) - : builder(b), loc(l), value(v) {} - - explicit operator Value() { return value; } - Value val() { return value; } - - ArithOpBuilder constantI(int64_t value, int64_t bits) { - Value val = builder.create( - loc, builder.getIntegerAttr(builder.getIntegerType(bits), value)); - return ArithOpBuilder(builder, loc, val); - } - - ArithOpBuilder extendUI(int32_t bits) { - Value ext = builder.create( - loc, builder.getIntegerType(bits), value); - return ArithOpBuilder(builder, loc, ext); - } - - ArithOpBuilder truncI(int64_t bits) { - if (value.getType().getIntOrFloatBitWidth() == bits) - return *this; - Value trunc = builder.create( - loc, builder.getIntegerType(bits), value); - return ArithOpBuilder(builder, loc, trunc); - } - - ArithOpBuilder linalgIndex(int32_t index) { - Value val = builder.create(loc, index); - return ArithOpBuilder(builder, loc, val); - } - - ArithOpBuilder indexCast(int32_t bitwidth) { - if (isa(value.getType())) { - Value cast = builder.create( - loc, builder.getIndexType(), value); - return ArithOpBuilder(builder, loc, cast); - } - - Value cast = builder.create( - loc, builder.getIntegerType(bitwidth), value); - return ArithOpBuilder(builder, loc, cast); - } - - ArithOpBuilder rotateLeft(int32_t rotation) { - int32_t bits = value.getType().getIntOrFloatBitWidth(); - ArithOpBuilder cLeft = constantI(rotation, bits); - ArithOpBuilder cRight = constantI(bits - rotation, bits); - ArithOpBuilder rLeft = (*this << cLeft); - ArithOpBuilder rRight = (*this >> cRight); - return rLeft | rRight; - } - - ArithOpBuilder operator+(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator*(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator|(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator^(ArithOpBuilder &rhs) { - Value res = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, res); - } - - ArithOpBuilder operator<<(ArithOpBuilder &rhs) { - Value shl = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, shl); - } - - ArithOpBuilder operator>>(ArithOpBuilder &rhs) { - Value shr = builder.create(loc, value, rhs.value); - return ArithOpBuilder(builder, loc, shr); - } - -private: - OpBuilder builder; - Location loc; - Value value; -}; - -std::pair splitI64(ArithOpBuilder i64) { - auto low = i64.truncI(32); - auto c32 = i64.constantI(/*value=*/32, /*bits=*/64); - auto high = (i64 >> c32).truncI(32); - return {low, high}; -} - -ArithOpBuilder fuseI32s(ArithOpBuilder low, ArithOpBuilder high) { - auto c32 = high.constantI(/*value=*/32, /*bits=*/64); - high = high.extendUI(64) << c32; - low = low.extendUI(64); - return low | high; -} - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. 
Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -std::pair -runThreeFry2xi32(ArithOpBuilder key0, ArithOpBuilder key1, - ArithOpBuilder initialState) { - ArithOpBuilder index = initialState.linalgIndex(0); - index = index.indexCast(64); - index = index + initialState; - - // Split into the 2xi32 used for threefry. - std::pair input = splitI64(index); - ArithOpBuilder input0 = input.first; - ArithOpBuilder input1 = input.second; - - // Magic number and rotation distances specified by the Threefry2x32 - // algorithm. - llvm::SmallVector rotations = {13, 15, 26, 6, 17, 29, 16, 24}; - ArithOpBuilder magic = key0.constantI(/*value=*/0x1bd11bda, /*bits=*/32); - - ArithOpBuilder key2 = magic ^ key0 ^ key1; - std::array ks{key0, key1, key2}; - std::array x{input0 + key0, input1 + key1}; - - // Performs a single round of the Threefry2x32 algorithm, with a rotation - // amount 'rotation'. - for (int i = 0; i < 5; ++i) { - int32_t rot = (4 * i) % rotations.size(); - int32_t k1 = (i + 1) % ks.size(); - int32_t k2 = (i + 2) % ks.size(); - - for (int j = 0; j < 4; ++j) { - x[0] = x[0] + x[1]; - x[1] = x[1].rotateLeft(rotations[rot + j]); - x[1] = x[0] ^ x[1]; - } - - ArithOpBuilder c = x[0].constantI(/*value=*/i + 1, /*bits=*/32); - x[0] = x[0] + ks[k1]; - x[1] = x[1] + ks[k2]; - x[1] = x[1] + c; - } - - return std::pair(x[0], x[1]); -} - -// Extract and potentially reconstruct the i32 key-pair as necessary. -std::pair extractKey32(OpBuilder &builder, Location loc, - Value store) { - auto storeTy = cast(store.getType()); - if (storeTy.getRank() != 1) - return {nullptr, nullptr}; - - Type storeETy = storeTy.getElementType(); - IntegerType i32Ty = builder.getIntegerType(32); - IntegerType i64Ty = builder.getIntegerType(64); - - if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { - Value idx0 = builder.create(loc, 0); - Value idx1 = builder.create(loc, 1); - Value key0 = builder.create(loc, store, idx0); - Value key1 = builder.create(loc, store, idx1); - key0 = builder.create(loc, i32Ty, key0); - key1 = builder.create(loc, i32Ty, key1); - return {key0, key1}; - } - - if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 0); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); - return std::pair(pair.first, pair.second); - } - - // TODO(#14859): Properly handle 128-bit storage keys. - if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 0); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); - return std::pair(pair.first, pair.second); - } - - return {nullptr, nullptr}; -} - -// Extract and potentially reconstruct the i64 state as necessary. -Value extractState64(OpBuilder &builder, Location loc, Value store) { - auto storeTy = cast(store.getType()); - if (storeTy.getRank() != 1) - return nullptr; - - Type storeETy = storeTy.getElementType(); - IntegerType i64Ty = builder.getIntegerType(64); - - if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 1); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - return cast; - } - - // TODO(#14859): Properly handle 128-bit storage keys. 
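// (Illustrative aside: the accepted rng state layouts at this point are
//   tensor<2xi64>: [key, counter]
//   tensor<3xi64>: [key, counter, ...]  (128-bit form, truncated per the TODO)
//   tensor<4xi32>: [key0, key1, counter packed as two i32 words]
// from which extractKey32/extractState64 recover an i32 key pair plus an
// i64 counter.)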
- if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { - Value idx1 = builder.create(loc, 1); - Value state = builder.create(loc, store, idx1); - Value cast = builder.create(loc, i64Ty, state); - return cast; - } - - if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { - Value idx2 = builder.create(loc, 2); - Value idx3 = builder.create(loc, 3); - - Value low = builder.create(loc, store, idx2); - Value high = builder.create(loc, store, idx3); - - ArithOpBuilder i64 = fuseI32s(ArithOpBuilder(builder, loc, high), - ArithOpBuilder(builder, loc, low)); - return builder.create(loc, i64Ty, i64.val()); - } - - return nullptr; -} - -Value setState64(OpBuilder &b, Location loc, Value store, Value state) { - auto storeTy = cast(store.getType()); - if (storeTy.getRank() != 1) - return nullptr; - - Type storeETy = storeTy.getElementType(); - - if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { - state = b.create(loc, storeETy, state); - Value idx1 = b.create(loc, 1); - return b.create(loc, storeTy, state, store, - ValueRange{idx1}); - } - - // TODO(#14859): Properly handle 128-bit storage keys. - if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { - state = b.create(loc, storeETy, state); - Value idx1 = b.create(loc, 1); - return b.create(loc, storeTy, state, store, - ValueRange{idx1}); - } - - if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { - Value idx2 = b.create(loc, 2); - Value idx3 = b.create(loc, 3); - std::pair states = - splitI64(ArithOpBuilder(b, loc, state)); - Value state0 = - b.create(loc, storeETy, states.first.val()); - Value state1 = - b.create(loc, storeETy, states.second.val()); - Value insert0 = b.create(loc, storeTy, state0, store, - ValueRange{idx2}); - Value insert1 = b.create(loc, storeTy, state1, insert0, - ValueRange{idx3}); - return insert1; - } - - return nullptr; -} - -Value reshapeToTarget(OpBuilder &builder, Location loc, ShapedType destTy, - Value src) { - auto srcTy = cast(src.getType()); - // Expand out to the target shape. - - auto reassociationIndices = - getReassociationIndicesForCollapse(destTy.getShape(), srcTy.getShape()); - if (reassociationIndices.has_value()) { - src = builder.create(loc, destTy, src, - reassociationIndices.value()); - } - - // It is also possible our target is Rank-0, then we would - // need to collapse. - reassociationIndices = - getReassociationIndicesForCollapse(srcTy.getShape(), destTy.getShape()); - if (reassociationIndices.has_value()) { - src = builder.create(loc, destTy, src, - reassociationIndices.value()); - } - - return src; -} - -// Compute the shape for computing three fry. -std::pair threeFry32Shape(ShapedType resultTy) { - if (resultTy.getRank() == 0) { - return {resultTy, 0}; - } - - ArrayRef shape = resultTy.getShape(); - uint64_t halfDim = - std::max_element(shape.begin(), shape.end()) - shape.begin(); - - for (int i = 0, s = shape.size(); i < s; i++) { - if (shape[i] & 0x1) - continue; - halfDim = i; - break; - } - - llvm::SmallVector newShape(shape); - newShape[halfDim] = (newShape[halfDim] + 1) / 2; - if (halfDim == (newShape.size() - 1)) { - newShape.push_back(1); - } - - return {RankedTensorType::get(newShape, resultTy.getElementType()), halfDim}; -} - -/// This implementation generates a 32-bit tensor of ThreeFry random numbers. -/// It matches the XLA implementation bit-exact and includes an inefficient -/// method of concatenating / slicing the pairs of generated numbers. 
-///
-/// We should consider dropping the complex slicing and simply generating
-/// 2x the values, then downcasting to 32 bits. That would substantially
-/// simplify the computation and avoid the concat / slice behavior.
-LogicalResult generateLinalgThreeFry32(OpBuilder &builder, Location loc,
-                                       ShapedType resultTy, Value &store,
-                                       Value &result) {
-  Type resultETy = resultTy.getElementType();
-
-  // Extract the stateful values as an i64 and increment the state ahead.
-  Value initialState = extractState64(builder, loc, store);
-  if (!initialState)
-    return failure();
-
-  std::pair<Value, Value> keys = extractKey32(builder, loc, store);
-  if (!keys.first || !keys.second)
-    return failure();
-
-  ArithOpBuilder key0(builder, loc, keys.first);
-  ArithOpBuilder key1(builder, loc, keys.second);
-
-  // Compute the intermediate type used to compute the ThreeFry values,
-  // including the dimension that was halved.
-  auto pair = threeFry32Shape(resultTy);
-  ShapedType intermediateType = pair.first;
-  int64_t halfDim = pair.second;
-  int64_t count = intermediateType.getNumElements();
-
-  // Compute the number of random i64s generated and increment state.
-  Value countVal =
-      builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
-  Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);
-
-  // Generate a 1D tensor for the random values.
-  Value destLeft = builder.create<tensor::EmptyOp>(
-      loc, ArrayRef<int64_t>({count}), resultETy);
-  Value destRight = builder.create<tensor::EmptyOp>(
-      loc, ArrayRef<int64_t>({count}), resultETy);
-
-  ShapedType destTy = llvm::cast<ShapedType>(destLeft.getType());
-
-  SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(1));
-  SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);
-
-  linalg::GenericOp generic = builder.create<linalg::GenericOp>(
-      loc, TypeRange{destTy, destTy},
-      /*inputs=*/ValueRange(),
-      /*outputs=*/ValueRange{destLeft, destRight},
-      /*indexingMaps=*/indexingMaps, iterators,
-      [&](OpBuilder &b, Location nestedLoc, ValueRange) {
-        // Grab the ThreeFry results and write one to each array.
-        auto split = runThreeFry2xi32(
-            key0, key1, ArithOpBuilder(b, nestedLoc, initialState));
-        auto first = split.first.truncI(resultETy.getIntOrFloatBitWidth());
-        auto second = split.second.truncI(resultETy.getIntOrFloatBitWidth());
-        b.create<linalg::YieldOp>(loc, ValueRange{first.val(), second.val()});
-      });
-
-  if (resultTy.getNumElements() == 1) {
-    result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0));
-    store = setState64(builder, loc, store, newState);
-    return success();
-  }
-
-  // Reshape to the target size and concatenate on the dimension following the
-  // half dimension.
-  Value random0 =
-      reshapeToTarget(builder, loc, intermediateType, generic.getResult(0));
-  Value random1 =
-      reshapeToTarget(builder, loc, intermediateType, generic.getResult(1));
-  Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>(
-      loc, ValueRange{random0, random1},
-      builder.getI64IntegerAttr(halfDim + 1));
-
-  // Collapse the concat dimension back into the parent.
-  llvm::SmallVector<int64_t> collapseShape(resultTy.getShape());
-  collapseShape[halfDim] =
-      collapseShape[halfDim] + (collapseShape[halfDim] & 1);
-  Value reshape = builder.create<mlir::stablehlo::ReshapeOp>(
-      loc, resultTy.clone(collapseShape), concatenate);
-
-  // Slice to only the required results.
-  llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
-  llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
-  Value slice = builder.create<mlir::stablehlo::SliceOp>(
-      loc, resultTy, reshape, builder.getDenseI64ArrayAttr(offset),
-      builder.getDenseI64ArrayAttr(resultTy.getShape()),
-      builder.getDenseI64ArrayAttr(stride));
-
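(Aside: the intermediate shape computed by threeFry32Shape above picks one dimension to halve, so that each ThreeFry invocation can supply two output values. A minimal standalone C++ sketch of that selection, assuming rank >= 1 since the rank-0 case returns early above; names are illustrative:

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    std::pair<std::vector<int64_t>, size_t>
    threeFryHalfShape(std::vector<int64_t> shape) {
      // Default to the largest dimension, then prefer the first even one.
      size_t halfDim =
          std::max_element(shape.begin(), shape.end()) - shape.begin();
      for (size_t i = 0; i < shape.size(); ++i) {
        if ((shape[i] & 1) == 0) {
          halfDim = i;
          break;
        }
      }
      shape[halfDim] = (shape[halfDim] + 1) / 2;  // Round up for odd sizes.
      if (halfDim == shape.size() - 1)
        shape.push_back(1);  // Leave a trailing dim to concatenate on.
      return {shape, halfDim};
    }

For example, a 5x6 result is generated as two 5x3x1 halves, concatenated into 5x3x2 on the dimension after the halved one, then reshaped back to 5x6; odd sizes additionally get sliced down to the requested extent, which is the concat / reshape / slice epilogue visible in the surrounding code.)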
-  // Set the new tensor values.
-  store = setState64(builder, loc, store, newState);
-  result = slice;
-
-  return success();
-}
-
-LogicalResult generateLinalgThreeFry64(OpBuilder &builder, Location loc,
-                                       ShapedType resultTy, Value &store,
-                                       Value &result) {
-  Type resultETy = resultTy.getElementType();
-  int64_t count = resultTy.getNumElements();
-
-  // Extract the stateful values as an i64 and increment the state ahead.
-  Value initialState = extractState64(builder, loc, store);
-  if (!initialState)
-    return failure();
-
-  std::pair<Value, Value> keys = extractKey32(builder, loc, store);
-  if (!keys.first || !keys.second)
-    return failure();
-
-  ArithOpBuilder key0(builder, loc, keys.first);
-  ArithOpBuilder key1(builder, loc, keys.second);
-
-  // Compute the number of random i64s generated and increment state.
-  Value countVal =
-      builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
-  Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);
-
-  // Generate a 1D tensor for the random values.
-  Value dest = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
-                                               resultETy);
-  ShapedType destTy = llvm::cast<ShapedType>(dest.getType());
-
-  SmallVector<AffineMap> indexingMaps(1, builder.getMultiDimIdentityMap(1));
-  SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);
-
-  auto random = builder.create<linalg::GenericOp>(
-      loc, destTy, /*inputs=*/ValueRange(),
-      /*outputs=*/ValueRange{dest},
-      /*indexingMaps=*/indexingMaps, iterators,
-      [&](OpBuilder &b, Location nestedLoc, ValueRange) {
-        // Generate the ThreeFry results, fuse, and return an i64.
-        auto split = runThreeFry2xi32(
-            key0, key1, ArithOpBuilder(b, nestedLoc, initialState));
-        Value result = fuseI32s(split.first, split.second).val();
-        b.create<linalg::YieldOp>(nestedLoc, result);
-      });
-
-  store = setState64(builder, loc, store, newState);
-  result = reshapeToTarget(builder, loc, resultTy, random.getResult(0));
-  return success();
-}
-
-using PhiloxKey = std::pair<ArithOpBuilder, ArithOpBuilder>;
-using PhiloxState = std::array<ArithOpBuilder, 4>;
-
-// Computes the high and low words from multiplying two 32-bit integers.
-// Per the paper, mulhi and mullo of the same arguments can be computed
-// simultaneously in a single instruction on x86 architectures.
-std::pair<ArithOpBuilder, ArithOpBuilder> multiplyHilo(ArithOpBuilder counter,
-                                                       ArithOpBuilder key) {
-  counter = counter.extendUI(64);
-  key = key.extendUI(64);
-  ArithOpBuilder product = counter * key;
-  ArithOpBuilder ci64 = counter.constantI(/*value=*/32, /*bits=*/64);
-  ArithOpBuilder hi = product >> ci64;
-  hi = hi.truncI(32);
-  product = product.truncI(32);
-  return std::pair{hi, product};
-}
-
-PhiloxState philoxRound(PhiloxState x, PhiloxKey key) {
-  // These are Philox-specific constants.
-  ArithOpBuilder m0 = x[0].constantI(0xD2511F53, 32);
-  ArithOpBuilder m1 = x[2].constantI(0xCD9E8D57, 32);
-  std::pair<ArithOpBuilder, ArithOpBuilder> p0 = multiplyHilo(x[0], m0);
-  std::pair<ArithOpBuilder, ArithOpBuilder> p1 = multiplyHilo(x[2], m1);
-
-  PhiloxState state = {p1.first ^ x[1] ^ key.first, p1.second,
-                       p0.first ^ x[3] ^ key.second, p0.second};
-  return state;
-}
-
-PhiloxKey raiseKey(PhiloxKey key) {
-  // These are Philox-specific constants.
-  ArithOpBuilder w0 = key.first.constantI(0x9E3779B9, 32);
-  ArithOpBuilder w1 = key.first.constantI(0xBB67AE85, 32);
-  return PhiloxKey{key.first + w0, key.second + w1};
-}
-
-// Implements the Philox 4x32 counter-based PRNG algorithm.
-// The Philox PRNG has been proposed in:
-// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
-// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -std::array runPhilox4x32(PhiloxKey key, - ArithOpBuilder state) { - ArithOpBuilder index = state.linalgIndex(0); - index = index.indexCast(64); - index = index + state; - - // Split into the 2xi32 used for threefry. - std::pair input = splitI64(index); - ArithOpBuilder input0 = input.first; - ArithOpBuilder input1 = input.second; - - // We initialize the state as such to match the XLA implementation. - PhiloxState state4 = {input0, input1, key.first, key.second}; - - // We perform 10 rounds to match the XLA implementation. - constexpr int kNumRounds = 10; - for (int round = 0; round < kNumRounds; ++round, key = raiseKey(key)) { - state4 = philoxRound(state4, key); - } - return state4; -} - -// Generates an array of primitive type U32 with the given shape containing -// random bits generated by the Philox algorithm. Returns the array and the new -// state of the random number generator. -LogicalResult generateLinalgPhilox32(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &store, - Value &result) { - Type resultETy = resultTy.getElementType(); - - Value initialState = extractState64(builder, loc, store); - if (!initialState) - return failure(); - - std::pair keys = extractKey32(builder, loc, store); - if (!keys.first || !keys.second) - return failure(); - - int64_t numElements = resultTy.getNumElements(); - int64_t count = (numElements + 3) / 4; - ShapedType intermediateType = - RankedTensorType::get({count, 1}, resultTy.getElementType()); - int64_t concatDim = 1; - - // Compute the number of random i64s generated and increment state. - Value countVal = - builder.create(loc, builder.getI64IntegerAttr(count)); - Value newState = builder.create(loc, initialState, countVal); - - // set up four outputs - Value dest0 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest1 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest2 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest3 = builder.create(loc, ArrayRef({count}), - resultETy); - - ShapedType destTy = cast(dest0.getType()); - - SmallVector indexingMaps(4, builder.getMultiDimIdentityMap(1)); - SmallVector iterators(1, utils::IteratorType::parallel); - - linalg::GenericOp generic = builder.create( - loc, TypeRange{destTy, destTy, destTy, destTy}, - /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{dest0, dest1, dest2, dest3}, - /*indexingMaps=*/indexingMaps, iterators, - [&](OpBuilder &b, Location nestedLoc, ValueRange) { - auto output = - runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first), - ArithOpBuilder(b, nestedLoc, keys.second)}, - ArithOpBuilder(b, nestedLoc, initialState)); - auto out0 = output[0].truncI(resultETy.getIntOrFloatBitWidth()); - auto out1 = output[1].truncI(resultETy.getIntOrFloatBitWidth()); - auto out2 = output[2].truncI(resultETy.getIntOrFloatBitWidth()); - auto out3 = output[3].truncI(resultETy.getIntOrFloatBitWidth()); - b.create( - loc, ValueRange{out0.val(), out1.val(), out2.val(), out3.val()}); - }); - - if (resultTy.getNumElements() == 1) { - result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); - store = setState64(builder, loc, store, newState); - return success(); - } - - Value r0 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); - Value r1 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); - Value r2 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(2)); - Value r3 = - 
reshapeToTarget(builder, loc, intermediateType, generic.getResult(3)); - - Value concatenate = builder.create( - loc, ValueRange{r0, r1, r2, r3}, builder.getI64IntegerAttr(concatDim)); - - // Collapse the concat dimension back into the parent. - llvm::SmallVector collapseShape(intermediateType.getShape()); - collapseShape[0] = collapseShape[0] * 4; - Value reshapeIntermediate = builder.create( - loc, resultTy.clone(collapseShape), concatenate); - - // Slice to only the required results. - collapseShape[0] = resultTy.getNumElements(); - - auto sliceResultTy = intermediateType.clone(collapseShape); - llvm::SmallVector offset(sliceResultTy.getRank(), 0); - llvm::SmallVector stride(sliceResultTy.getRank(), 1); - Value slice = builder.create( - loc, sliceResultTy, reshapeIntermediate, - builder.getDenseI64ArrayAttr(offset), - builder.getDenseI64ArrayAttr(collapseShape), - builder.getDenseI64ArrayAttr(stride)); - Value reshapeResult = - builder.create(loc, resultTy, slice); - - // Set the new tensor values. - store = setState64(builder, loc, store, newState); - result = reshapeResult; - - return success(); -} - -LogicalResult generateLinalgPhilox64(OpBuilder &builder, Location loc, - ShapedType resultTy, Value &store, - Value &result) { - Type resultETy = resultTy.getElementType(); - - Value initialState = extractState64(builder, loc, store); - if (!initialState) - return failure(); - - std::pair keys = extractKey32(builder, loc, store); - if (!keys.first || !keys.second) - return failure(); - - int64_t numElements = resultTy.getNumElements(); - int64_t count = (numElements + 1) / 2; - ShapedType intermediateType = - RankedTensorType::get({count, 1}, resultTy.getElementType()); - int64_t concatDim = 1; - - // Compute the number of random i64s generated and increment state. - Value countVal = - builder.create(loc, builder.getI64IntegerAttr(count)); - Value newState = builder.create(loc, initialState, countVal); - - // set up four outputs - Value dest0 = builder.create(loc, ArrayRef({count}), - resultETy); - Value dest1 = builder.create(loc, ArrayRef({count}), - resultETy); - ShapedType destTy = cast(dest0.getType()); - - SmallVector indexingMaps(2, builder.getMultiDimIdentityMap(1)); - SmallVector iterators(1, utils::IteratorType::parallel); - - linalg::GenericOp generic = builder.create( - loc, TypeRange{destTy, destTy}, - /*inputs=*/ValueRange(), - /*outputs=*/ValueRange{dest0, dest1}, - /*indexingMaps=*/indexingMaps, iterators, - [&](OpBuilder &b, Location nestedLoc, ValueRange) { - auto output = - runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first), - ArithOpBuilder(b, nestedLoc, keys.second)}, - ArithOpBuilder(b, nestedLoc, initialState)); - auto out0 = output[0]; - auto out1 = output[1]; - auto out2 = output[2]; - auto out3 = output[3]; - Value result1 = fuseI32s(out0, out1).val(); - Value result2 = fuseI32s(out2, out3).val(); - b.create(loc, ValueRange{result1, result2}); - }); - - if (resultTy.getNumElements() == 1) { - result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); - store = setState64(builder, loc, store, newState); - return success(); - } - - Value r0 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); - Value r1 = - reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); - Value concatenate = builder.create( - loc, ValueRange{r0, r1}, builder.getI64IntegerAttr(concatDim)); - - // Collapse the concat dimension back into the parent. 
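(Aside: both Philox generators above lower the standard Philox-4x32-10 keystream. A minimal standalone C++ sketch with the same multipliers, Weyl key increments, and round count; names are illustrative, and the counter layout mirrors the splitI64 convention used above:

    #include <array>
    #include <cstdint>

    struct PhiloxKey32 {
      uint32_t k0, k1;
    };

    // One Philox-4x32 round: two 32x32->64 multiplies, then xors with the
    // passed-in state and keys, exactly as philoxRound above.
    static std::array<uint32_t, 4> philoxRound(std::array<uint32_t, 4> x,
                                               PhiloxKey32 key) {
      uint64_t p0 = static_cast<uint64_t>(x[0]) * 0xD2511F53u;
      uint64_t p1 = static_cast<uint64_t>(x[2]) * 0xCD9E8D57u;
      return {static_cast<uint32_t>(p1 >> 32) ^ x[1] ^ key.k0,
              static_cast<uint32_t>(p1),
              static_cast<uint32_t>(p0 >> 32) ^ x[3] ^ key.k1,
              static_cast<uint32_t>(p0)};
    }

    // Ten rounds, raising the key with golden-ratio Weyl increments after
    // each round, as in runPhilox4x32 above.
    std::array<uint32_t, 4> philox4x32(PhiloxKey32 key, uint64_t counter) {
      std::array<uint32_t, 4> x = {static_cast<uint32_t>(counter),
                                   static_cast<uint32_t>(counter >> 32),
                                   key.k0, key.k1};
      for (int round = 0; round < 10; ++round) {
        x = philoxRound(x, key);
        key.k0 += 0x9E3779B9u;
        key.k1 += 0xBB67AE85u;
      }
      return x;
    }

The 32-bit path above writes the four output words to four parallel tensors and truncates to the result element type; the 64-bit path fuses them into two i64s.)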
-  llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape());
-  collapseShape[0] = collapseShape[0] * 2;
-  Value reshapeIntermediate = builder.create<mlir::stablehlo::ReshapeOp>(
-      loc, resultTy.clone(collapseShape), concatenate);
-
-  // Slice to only the required results.
-  collapseShape[0] = resultTy.getNumElements();
-
-  auto sliceResultTy = intermediateType.clone(collapseShape);
-  llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0);
-  llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1);
-  Value slice = builder.create<mlir::stablehlo::SliceOp>(
-      loc, sliceResultTy, reshapeIntermediate,
-      builder.getDenseI64ArrayAttr(offset),
-      builder.getDenseI64ArrayAttr(collapseShape),
-      builder.getDenseI64ArrayAttr(stride));
-  Value reshapeResult =
-      builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice);
-
-  // Set the new tensor values.
-  store = setState64(builder, loc, store, newState);
-  result = reshapeResult;
-
-  return success();
-}
-
-LogicalResult generateLinalgThreeFry(OpBuilder &builder, Location loc,
-                                     ShapedType resultTy, Value &state,
-                                     Value &result) {
-  Type eTy = resultTy.getElementType();
-  unsigned bitwidth = eTy.getIntOrFloatBitWidth();
-
-  if (bitwidth == 64) {
-    return generateLinalgThreeFry64(builder, loc, resultTy, state, result);
-  }
-  if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
-    return generateLinalgThreeFry32(builder, loc, resultTy, state, result);
-  }
-
-  return failure();
-}
-
-LogicalResult generateLinalgPhilox(OpBuilder &builder, Location loc,
-                                   ShapedType resultTy, Value &state,
-                                   Value &result) {
-  Type eTy = resultTy.getElementType();
-  unsigned bitwidth = eTy.getIntOrFloatBitWidth();
-  if (bitwidth == 64) {
-    return generateLinalgPhilox64(builder, loc, resultTy, state, result);
-  }
-
-  // The 32-bit implementation truncates to the result element type.
-  if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
-    return generateLinalgPhilox32(builder, loc, resultTy, state, result);
-  }
-
-  return failure();
-}
-
-struct RngBitGeneratorConverter final
-    : OpConversionPattern<mlir::stablehlo::RngBitGeneratorOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::stablehlo::RngBitGeneratorOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value state = adaptor.getInitialState();
-    auto resultTy = dyn_cast_or_null<ShapedType>(
-        getTypeConverter()->convertType(op.getResult(1).getType()));
-    if (!resultTy) {
-      return rewriter.notifyMatchFailure(op, "type conversion failed");
-    }
-
-    if (op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::THREE_FRY) {
-      Value random;
-      if (failed(
-              generateLinalgThreeFry(rewriter, loc, resultTy, state, random))) {
-        return failure();
-      }
-      rewriter.replaceOp(op, {state, random});
-      return success();
-    }
-
-    if (op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::PHILOX ||
-        op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::DEFAULT) {
-      Value random;
-      if (failed(
-              generateLinalgPhilox(rewriter, loc, resultTy, state, random))) {
-        return failure();
-      }
-      rewriter.replaceOp(op, {state, random});
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-struct RngUniformConversion final
-    : OpConversionPattern<mlir::stablehlo::RngOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::stablehlo::RngOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // We only handle uniform distributions.
-    if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::UNIFORM) {
-      return failure();
-    }
-    // TODO(raikonenfnu): Handle other element types as well.
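(Aside: the pattern continues below by emitting a linalg.generic whose body seeds a linear congruential generator with the element indices, using the classic glibc constants 1103515245 and 12345 and a scale of 2.3283064e-10, which is approximately 2^-32. A plain-C++ rendering of that recurrence under the same constants; the helper name is hypothetical:

    #include <cstdint>

    float lcgUniform(const int64_t *indices, int rank, float minVal,
                     float maxVal) {
      // Unsigned arithmetic so the wraparound of the emitted i32 ops is
      // well defined in C++.
      uint32_t state = 0;  // kInitialSeed in the pattern below.
      for (int i = 0; i < rank; ++i) {
        // temp_i = (i32(index_i) + temp_{i-1}) * mult + incr
        state = (static_cast<uint32_t>(indices[i]) + state) * 1103515245u +
                12345u;
      }
      // rand(min, max) = sitofp(state) * ((max - min) * 2^-32) + min; the
      // signed cast matches the arith.sitofp in the emitted region.
      float scale = (maxVal - minVal) * 2.3283064e-10f;
      return static_cast<float>(static_cast<int32_t>(state)) * scale + minVal;
    }

)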
-    auto minTy = dyn_cast<ShapedType>(adaptor.getA().getType());
-    auto maxTy = dyn_cast<ShapedType>(adaptor.getB().getType());
-    if (!isa<FloatType>(minTy.getElementType()) ||
-        !isa<FloatType>(maxTy.getElementType())) {
-      return rewriter.notifyMatchFailure(
-          op, "expected min/max for rng op to be FloatType");
-    }
-    auto targetTy = dyn_cast_or_null<ShapedType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
-    if (!targetTy) {
-      return rewriter.notifyMatchFailure(
-          op, "expected target shape of rng op to be ShapedType");
-    }
-    auto loc = op.getLoc();
-    Value emptyTensor =
-        getEmptyTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands());
-    // Create the index maps using the target tensor's rank.
-    auto targetRank = targetTy.getRank();
-    SmallVector<AffineMap> indexingMaps(
-        2, AffineMap::get(targetRank, /*symbolCount=*/0,
-                          SmallVector<AffineExpr>({}), rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank));
-    const int kInitialSeed = 0;
-
-    // Generic region with an LCG algorithm that makes use of the element
-    // index, from: https://reviews.llvm.org/D101364
-    auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, /*resultTensors=*/targetTy,
-        /*inputs=*/
-        ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]},
-        /*outputs=*/emptyTensor, indexingMaps,
-        getParallelAndReductionIterators(/*nLoops=*/targetRank,
-                                         /*nReduction=*/0),
-        [&](OpBuilder &b, Location loc, ValueRange args) {
-          llvm::SmallVector<Value> updateVec = {b.create<arith::ConstantOp>(
-              loc, b.getI32IntegerAttr(kInitialSeed))};
-          Value multiplier = b.create<arith::ConstantOp>(
-              loc, b.getI32IntegerAttr(1103515245));
-          Value incrementStep =
-              b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(12345));
-          // For an output tensor with rank N:
-          // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr
-          // ...
-          // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr
-          for (int i = 0; i < targetRank; i++) {
-            Value update = updateVec.back();
-            Value ind = b.create<linalg::IndexOp>(loc, i);
-            Value castInd =
-                b.create<arith::IndexCastOp>(loc, b.getI32Type(), ind);
-            Value addRes = b.create<arith::AddIOp>(loc, castInd, update);
-            Value multRes = b.create<arith::MulIOp>(loc, addRes, multiplier);
-            Value incRes = b.create<arith::AddIOp>(loc, multRes, incrementStep);
-            updateVec.push_back(incRes);
-          }
-          // Scaling = (max - min) * const(F64, 2.3283064E-10)
-          // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)).
-          Value epsilon = b.create<arith::ConstantOp>(
-              loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10));
-          Value range = b.create<arith::SubFOp>(loc, args[1], args[0]);
-          Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
-          // Res = cast(T, cast(F64, tempN) * scaling + min)
-          Value updateCast = b.create<arith::SIToFPOp>(
-              loc, targetTy.getElementType(), updateVec.back());
-          Value scaleUpdate = b.create<arith::MulFOp>(loc, updateCast, scale);
-          Value res = b.create<arith::AddFOp>(loc, scaleUpdate, args[0]);
-          b.create<linalg::YieldOp>(loc, res);
-        },
-        linalg::getPrunedAttributeList(op));
-    rewriter.replaceOp(op, linalgOp.getResults());
-    return success();
-  }
-};
-} // namespace
-
-namespace detail {
-void populateStableHloRandomToLinalgConversionPatterns(
-    MLIRContext *context, TypeConverter &typeConverter,
-    RewritePatternSet *patterns) {
-  patterns->add<RngBitGeneratorConverter, RngUniformConversion>(typeConverter,
-                                                                context);
-}
-} // namespace detail
-} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
deleted file mode 100644
index ce8ba53ac83c..000000000000
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
+++ /dev/null
@@ -1,723 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-// Implements logic for lowering StableHLO reduction ops to the Linalg dialect.
-// These patterns are separated out into their own file to save on compilation
-// time.
-
-#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
-#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "stablehlo/dialect/StablehloOps.h"
-
-namespace mlir::iree_compiler::stablehlo {
-namespace {
-/// Returns true when reduction `op` is not supported and should be filtered
-/// out.
-static bool isUnsupported(mlir::stablehlo::ReduceOp op) {
-  // Empty reductions are not supported. We expect canonicalization patterns to
-  // handle them.
-  if (op.getDimensions().empty())
-    return true;
-
-  // We require all reduce shapes to be the same, up to the element types, so
-  // we can just use the first operand and the first result as a
-  // representative.
-  if (auto inputTy =
-          dyn_cast<RankedTensorType>(op.getInputs().getType().front())) {
-    return llvm::is_contained(inputTy.getShape(), 0);
-  }
-
-  return false;
-}
-
-/// Returns a permutation AffineMap that puts all reduction dimensions last.
-/// The parallel dimensions keep their relative order and the reduction
-/// dimensions follow in the order given. E.g., if `rank` is 4 and
-/// `reductionDims` is {1, 3}, then
-/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
-/// the AffineMap is returned.
-AffineMap getTransposeMapForReduction(MLIRContext *context, int rank,
-                                      ArrayRef<int64_t> reductionDims) {
-  llvm::SmallSetVector<int64_t, 4> s(reductionDims.begin(),
-                                     reductionDims.end());
-
-  SmallVector<int64_t> permutation;
-  for (int i = 0; i < rank; ++i) {
-    if (!s.contains(i)) {
-      permutation.push_back(i);
-    }
-  }
-
-  llvm::append_range(permutation, reductionDims);
-  auto map = AffineMap::getPermutationMap(permutation, context);
-  return inversePermutation(map);
-}
-
-SmallVector<Value>
-getReduceOpEmptyTensorDynSizes(OpBuilder &b, Location loc, Value arg,
-                               ShapedType resultType,
-                               ArrayRef<int64_t> reductionDims) {
-  llvm::SmallSetVector<int64_t, 4> s(reductionDims.begin(),
-                                     reductionDims.end());
-
-  SmallVector<int64_t> parallelDims;
-  SmallVector<Value> dynShape;
-  int rank = cast<RankedTensorType>(arg.getType()).getRank();
-  for (int i = 0, j = 0; i < rank; ++i) {
-    if (s.contains(i))
-      continue;
-    if (!resultType.isDynamicDim(j++))
-      continue;
-    dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
-  }
-
-  return dynShape;
-}
-
-struct ReduceRegionReturnOpConversion final
-    : OpConversionPattern<mlir::stablehlo::ReturnOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::stablehlo::ReturnOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!isInBodyOfLinalgOps(op)) {
-      return failure();
-    }
-
-    SmallVector<Value> operands(adaptor.getOperands());
-    for (Value &operand : operands) {
-      if (isa<ShapedType>(operand.getType())) {
-        Location loc = operand.getLoc();
-        operand = rewriter.create<tensor::ExtractOp>(loc, operand);
-      }
-    }
-    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
-    return success();
-  }
-};
-
-struct ReduceOpToGenericConverter final
-    : OpConversionPattern<mlir::stablehlo::ReduceOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (isUnsupported(op)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "unsupported reduce (noop or
empty)"); - } - - Location loc = op.getLoc(); - - int numOperands = static_cast(adaptor.getInputs().size()); - - if (llvm::any_of(adaptor.getInputs(), [](Value v) { - return !isa(v.getType()); - })) { - return rewriter.notifyMatchFailure(op, "expects known-rank args"); - } - auto srcRank = cast(adaptor.getInputs()[0].getType()).getRank(); - - SmallVector reductionDims = llvm::to_vector(op.getDimensions()); - - SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - - SmallVector outputs; - SmallVector indexingMaps; - for (auto [operand, initValue, resultType] : llvm::zip_equal( - adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) { - // Check if init_value is constant. If so, inline the value into the - // region. - initValue = rewriter.createOrFold(loc, initValue); - - SmallVector dynShape = getReduceOpEmptyTensorDynSizes( - rewriter, loc, operand, cast(resultType), reductionDims); - auto emptyTensor = - getEmptyTensor(rewriter, loc, cast(resultType), dynShape); - Value filledTensor = - rewriter.create(loc, initValue, emptyTensor).result(); - outputs.push_back(filledTensor); - } - - // Prepare indexing maps for linalg generic op. The elements are for src - // and dst. Transpose `src` to make the reduction loops be the innermost, - // because it's easier to fully utilize processors. - indexingMaps.append(numOperands, - getTransposeMapForReduction(rewriter.getContext(), - static_cast(srcRank), - reductionDims)); - - // The indexing map of `dst` should drop the reduction loops. Since the - // reduction loops now are all in the innermost, drops - // `reduction_dims.size()` dimensions. We don't need an inverse - // permutation here because they are the same. - SmallVector exprs; - for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i) { - exprs.push_back(rewriter.getAffineDimExpr(i)); - } - indexingMaps.append(numOperands, - AffineMap::get(srcRank, /*symbolCount=*/0, exprs, - rewriter.getContext())); - - auto linalgOp = rewriter.create( - loc, /*resultTensorTypes=*/resultTypes, adaptor.getInputs(), - /*outputBuffers=*/ValueRange{outputs}, indexingMaps, - getParallelAndReductionIterators(srcRank, reductionDims.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. The reduce op region apply function - // has a signature (lhs, rhs) -> output, all of the same tensor type t. - // This is converted to a function with the same signature but with - // element types. E.g., "(tensor, tensor) -> tensor" will - // be converted to "(f32, f32, f32)". - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getBody(), region, region.end()); - TypeConverter::SignatureConversion signatureConverter(numOperands * 2); - - // The stablehlo ReduceOp requires that the seed be used as a LHS operand - // inside the region, and the seed is encoded in linalg in the initial out - // value, so modify the signature of the block and the value mappings, so - // the output args will correlate with the original LHS and the inputs - // correlate with the original RHS. - for (auto [idx, val] : llvm::enumerate(op.getInputs())) { - signatureConverter.addInputs( - /*origInputNo=*/idx + numOperands, - // type for the new operand number 'idx'. 
- typeConverter->convertType( - cast(val.getType()).getElementType())); - } - for (auto [idx, val] : llvm::enumerate(op.getInitValues())) { - signatureConverter.addInputs( - /*origInputNo=*/idx, - // type for the new operand number 'idx' + 'numOperands'. - typeConverter->convertType( - cast(val.getType()).getElementType())); - } - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct ReduceOpToReduceConverter final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (isUnsupported(op)) { - return rewriter.notifyMatchFailure(op, - "unsupported reduce (noop or empty)"); - } - - auto reductionDims = llvm::to_vector(op.getDimensions()); - // stablehlo.reduce doesn't specify the order of the reduction dimensions. - llvm::sort(reductionDims); - - auto toRankedTensor = [](Value v) -> RankedTensorType { - return dyn_cast(v.getType()); - }; - - SmallVector outputs; - SmallVector operandTypes, initTypes; - SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - - Location loc = op.getLoc(); - for (auto [operand, initValue, resultType] : llvm::zip_equal( - adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) { - auto initType = toRankedTensor(initValue); - if (!initType) - return rewriter.notifyMatchFailure(op, - "expects known-rank init values"); - initTypes.push_back(initType); - auto operandType = toRankedTensor(operand); - if (!operandType) - return rewriter.notifyMatchFailure(op, "expects known-rank operands"); - operandTypes.push_back(operandType); - initValue = rewriter.createOrFold(loc, initValue); - auto tensorResultType = cast(resultType); - // For linalg.reduce, the result type's dimensions must match the input's - // dimensions, whereas StableHLO allows replacing static dimensions with - // dynamic ones. - SmallVector resultShape; - SmallVector dynShape; - for (auto [index, dim] : - llvm::enumerate(cast(operand.getType()).getShape())) { - if (!llvm::is_contained(reductionDims, index)) { - resultShape.push_back(dim); - if (ShapedType::isDynamic(dim)) { - dynShape.push_back( - rewriter.create(loc, operand, index)); - } - } - } - - Value emptyTensor = rewriter.create( - loc, resultShape, tensorResultType.getElementType(), dynShape); - Value filledTensor = - rewriter.create(loc, initValue, emptyTensor).result(); - outputs.push_back(filledTensor); - } - - auto linalgOp = rewriter.create( - loc, adaptor.getInputs(), outputs, reductionDims, - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - Region ®ion = linalgOp.getRegion(); - rewriter.inlineRegionBefore(op.getBody(), region, region.end()); - - // Convert the signature of the body. The reduce op 'computation' region - // apply function has a signature with tensor types, this is converted to a - // function with element types. E.g. the signature "(tensor, - // tensor) -> tensor" will be converted to "(f32, f32) -> f32". - // Also, we need to swap the operands of the function. The stablehlo.reduce - // op expects the init values to be the first parameters of the apply - // function, while the linalg.reduction op expects the init values as the - // last parameters of the 'combiner' region apply function. 
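(The renumbering below implements exactly that swap; as a standalone illustration, with a hypothetical helper, for n reduction inputs:

    #include <cassert>

    // stablehlo.reduce region args: (acc_0..acc_{n-1}, elem_0..elem_{n-1}).
    // linalg combiner args:         (elem_0..elem_{n-1}, acc_0..acc_{n-1}).
    int remapReduceArg(int origArgNo, int n) {
      assert(origArgNo >= 0 && origArgNo < 2 * n);
      return origArgNo < n ? origArgNo + n   // accumulator moves to the back
                           : origArgNo - n;  // element moves to the front
    }

)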
- TypeConverter::SignatureConversion signatureConverter( - linalgOp.getNumDpsInputs() * 2); - assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits()); - for (const auto &[idx, val] : llvm::enumerate(operandTypes)) { - signatureConverter.addInputs( - /*origInputNo=*/idx + linalgOp.getNumDpsInputs(), - // type for new operand number 'idx'. - typeConverter->convertType(val.getElementType())); - } - for (const auto &[idx, val] : llvm::enumerate(initTypes)) { - signatureConverter.addInputs( - /*origInputNo=*/idx, - // type for new operand number 'idx' + linalgOp.getNumInputs() - typeConverter->convertType(val.getElementType())); - } - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - - // Cast the result to the correct type. - SmallVector results; - for (auto [result, resultType] : - llvm::zip(linalgOp.getResults(), resultTypes)) { - results.push_back( - rewriter.createOrFold(loc, resultType, result)); - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -struct ReduceWindowOpOnTensorsGenericConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = op->getContext(); - Location loc = op.getLoc(); - llvm::SmallVector initValues = adaptor.getInitValues(); - llvm::SmallVector resultTypes; - if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) - return failure(); - auto numOperands = initValues.size(); - - llvm::SmallVector windowDimensions(op.getWindowDimensions()); - - llvm::SmallVector padding; - if (op.getPadding()) { - padding = extract1DVector(*op.getPadding()); - } - - llvm::SmallVector baseDilations; - if (op.getBaseDilations()) { - baseDilations = llvm::to_vector(*op.getBaseDilations()); - } - - llvm::SmallVector windowStrides(windowDimensions.size(), 1); - if (op.getWindowStrides()) { - windowStrides = llvm::to_vector(*op.getWindowStrides()); - } - - llvm::SmallVector windowDilations(windowDimensions.size(), 1); - if (op.getWindowDilations()) { - windowDilations = llvm::to_vector(*op.getWindowDilations()); - } - - auto rank = static_cast(windowDimensions.size()); - SmallVector srcExprs; - SmallVector windowExprs; - SmallVector dstExprs; - SmallVector filteredWindowDims; - - int windowDim = 0; - for (int64_t i = 0; i < rank; i++) { - AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx); - - if (windowStrides[i] != 1) - srcExpr = srcExpr * windowStrides[i]; - - if (windowDimensions[i] != 1) { - filteredWindowDims.push_back(windowDimensions[i]); - AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx); - windowExprs.push_back(windowExpr); - - if (windowDilations[i] != 1) - windowExpr = windowExpr * windowDilations[i]; - - srcExpr = srcExpr + windowExpr; - windowDim++; - } - - srcExprs.push_back(srcExpr); - dstExprs.push_back(mlir::getAffineDimExpr(i, ctx)); - } - - SmallVector inferredMaps(3, AffineMap::get(ctx)); - if (rank > 0) { - inferredMaps = - AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs}, ctx); - } - - SmallVector indexingMaps; - - indexingMaps.append(numOperands, inferredMaps[0]); - indexingMaps.append(1, inferredMaps[1]); - indexingMaps.append(numOperands, inferredMaps[2]); - - // Setup the initial values. 
- llvm::SmallVector broadcastValues; - for (uint64_t i = 0, s = initValues.size(); i < s; i++) { - Value initValue = initValues[i]; - auto resultTy = llvm::cast(resultTypes[i]); - if (!resultTy.hasStaticShape()) - return failure(); - - auto broadcastSizes = rewriter.getDenseI64ArrayAttr(resultTy.getShape()); - broadcastValues.push_back(rewriter.create( - loc, resultTy, initValue, broadcastSizes)); - } - - llvm::SmallVector inputs = llvm::to_vector(adaptor.getInputs()); - - // Pad as necessary. - if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) || - llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) { - llvm::SmallVector staticLows(rank, 0); - llvm::SmallVector staticHighs(rank, 0); - for (int64_t i = 0; i < static_cast(padding.size()); i += 2) { - staticLows[i / 2] = padding[i]; - staticHighs[i / 2] = padding[i + 1]; - } - // Translate base dilation into interior padding. - llvm::SmallVector staticInteriors(rank, 0); - for (auto [idx, dilation] : llvm::enumerate(baseDilations)) { - staticInteriors[idx] = dilation - 1; - } - - for (auto [input, initValue] : llvm::zip(inputs, initValues)) { - input = rewriter.create( - loc, input, initValue, staticLows, staticHighs, staticInteriors); - } - } - - // Add the extra input for the reduction dimension. - inputs.push_back(rewriter.create(loc, filteredWindowDims, - rewriter.getF32Type())); - - auto linalgOp = rewriter.create( - loc, /*resultTensors=*/resultTypes, - /*inputs=*/inputs, - /*outputs=*/broadcastValues, indexingMaps, - getParallelAndReductionIterators(rank + filteredWindowDims.size(), - filteredWindowDims.size()), - /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); - - // Convert the signature of the body. This includes converting scalar - // tensors to their scalar values and inserting an additional block arg for - // the window arg. - Region ®ion = linalgOp.getRegion(); - rewriter.cloneRegionBefore(op.getBody(), region, region.end()); - - TypeConverter::SignatureConversion signatureConverter( - inputs.size() + op->getNumResults() - 1); - - // ReduceWindow requires that the seed be used as a LHS operand inside the - // region, and the seed is encoded in linalg in the initial out value, so - // modify the signature of the block and the value mappings, so the output - // args will correlate with the LHS and the inputs correlate with the RHS. - for (auto [i, type] : llvm::enumerate(resultTypes)) { - auto idx = inputs.size() + i - 1; - signatureConverter.addInputs(idx, - cast(type).getElementType()); - } - - signatureConverter.addInputs( - cast(inputs.back().getType()).getElementType()); - - for (auto [i, input] : - llvm::enumerate(ArrayRef(inputs).drop_back())) { - signatureConverter.addInputs( - i, cast(input.getType()).getElementType()); - } - - rewriter.applySignatureConversion(®ion.front(), signatureConverter, - getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); - return success(); - } -}; - -struct ReduceWindowOpConversion final - : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - /// Get the operation used for reduction applied to `result_index`th result. - /// Its expected to be a binary operation that consumes `result_index`th and - /// `result_index + getInputs().size`th arguments of the body. 
- static Operation *getReductionOp(mlir::stablehlo::ReduceWindowOp op, - int resultIndex) { - auto returnOp = - cast(op.getBody().front().getTerminator()); - Operation *computeOp = returnOp.getResults()[resultIndex].getDefiningOp(); - if (computeOp->getNumOperands() != 2) - return nullptr; - auto arg0 = llvm::dyn_cast(computeOp->getOperand(0)); - auto arg1 = llvm::dyn_cast(computeOp->getOperand(1)); - if (!arg0 || !arg1) - return nullptr; - int64_t arg0Num = arg0.getArgNumber(); - int64_t arg1Num = arg1.getArgNumber(); - int64_t otherArgIndex = resultIndex + op.getInputs().size(); - if (arg0Num == resultIndex && arg1Num == otherArgIndex) - return computeOp; - if (arg0Num == otherArgIndex && arg1Num == resultIndex && - computeOp->hasTrait()) - return computeOp; - return nullptr; - } - - /// stablehlo.reduce_window is mapped to a linalg.pooling operation. The type - /// of the pooling is determined based on the body of the reduce window - /// operation. This class enumerates the different variants. - enum class PoolingType { - kInvalid, - k2DMin, - k3DMin, - k2DMax, - k3DMax, - k2DAdd, - k3DAdd, - }; - - static PoolingType getPoolingType(mlir::stablehlo::ReduceWindowOp reduceOp, - int resultIndex) { - auto rank = llvm::cast(reduceOp.getResultTypes()[resultIndex]) - .getRank(); - if (Operation *op = getReductionOp(reduceOp, resultIndex)) { - if (isa(*op) && rank == 4) - return PoolingType::k2DMin; - if (isa(*op) && rank == 5) - return PoolingType::k3DMin; - if (isa(*op) && rank == 4) - return PoolingType::k2DMax; - if (isa(*op) && rank == 5) - return PoolingType::k3DMax; - if (isa(*op) && rank == 4) - return PoolingType::k2DAdd; - if (isa(*op) && rank == 5) - return PoolingType::k3DAdd; - } - return PoolingType::kInvalid; - } - - LogicalResult - matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - int rank = llvm::cast(op.getResultTypes()[0]).getRank(); - if (rank != 4 && rank != 5) { - return rewriter.notifyMatchFailure( - op, "expected NHWC/NDHWC pooling-based op"); - } - - if (op.getPadding() && !isSplatValue(*op.getPadding(), 0)) { - return rewriter.notifyMatchFailure(op, "require paddings are all zero"); - } - - if (op.getBaseDilations() && !isSplatValue(*op.getBaseDilations(), 1)) { - return rewriter.notifyMatchFailure(op, "expected undilated base"); - } - - int lastDim = rank - 1; - SmallVector fakeWindowShapes; - for (int i = 1; i < lastDim; ++i) { - fakeWindowShapes.push_back(op.getWindowDimensions()[i]); - } - - if (op.getWindowStrides() && - (op.getWindowStrides().value()[0] != 1 || - op.getWindowStrides().value()[lastDim] != 1)) { - return rewriter.notifyMatchFailure( - op, "expected window_strides to be [1,x,y,(z),1]"); - } - if (op.getWindowDimensions()[0] != 1 || - op.getWindowDimensions()[lastDim] != 1) { - return rewriter.notifyMatchFailure( - op, "expected window_dimensions to be [1,x,y,(z),1]"); - } - - Attribute strides; - SmallVector vec; - if (op.getWindowStridesAttr()) { - for (int i = 1; i < lastDim; ++i) { - vec.push_back(op.getWindowStrides().value()[i]); - } - } else { - vec.assign(rank - 2, 1); - } - strides = rewriter.getI64VectorAttr(vec); - - Attribute dilations; - vec.clear(); - if (op.getWindowDilations()) { - for (int i = 1; i < lastDim; ++i) { - vec.push_back(op.getWindowDilations().value()[i]); - } - } else { - vec.assign(rank - 2, 1); - } - dilations = rewriter.getI64VectorAttr(vec); - - SmallVector poolingOps; - - ValueRange operands = 
adaptor.getInputs(); - ValueRange initValues = adaptor.getInitValues(); - for (auto it : llvm::zip(op.getResults(), operands, initValues)) { - OpResult result = std::get<0>(it); - Value input = std::get<1>(it); - Value initValue = std::get<2>(it); - auto resultType = cast(result.getType()); - if (!cast(input.getType()).getElementType().isF32()) { - return rewriter.notifyMatchFailure(op, - "expected element type to be f32"); - } - - // Create a fake window dimension. - auto fakeWindowDims = rewriter.create( - loc, fakeWindowShapes, resultType.getElementType()); - - SmallVector resultDynamicDims; - for (const auto &en : llvm::enumerate(resultType.getShape())) { - if (!ShapedType::isDynamic(en.value())) - continue; - Value dimSize = rewriter.create(loc, input, en.index()); - if (en.index() == 0 || static_cast(en.index()) == rank - 1) { - // batch dims and channel dims can be derived from input dims - // directly. - resultDynamicDims.push_back(dimSize); - } else { - auto i = en.index() - 1; - auto stride = - llvm::cast(strides).getValues()[i]; - auto dilation = llvm::cast(dilations) - .getValues()[i]; - // let j = i * stride - // output[i] = reduce( input[j, j + window_size * dilation) ) - Value offset = rewriter.create( - loc, fakeWindowShapes[i] * dilation); - dimSize = rewriter.create(loc, dimSize, offset); - dimSize = rewriter.create( - loc, dimSize, - rewriter.create(loc, stride)); - dimSize = rewriter.create( - loc, dimSize, rewriter.create(loc, 1)); - resultDynamicDims.push_back(dimSize); - } - } - Value emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), - resultDynamicDims); - - initValue = rewriter.create(loc, initValue); - Value filledInitTensor = - rewriter.create(loc, initValue, emptyTensor) - .getResult(0); - auto createOp = [&](auto *typePtr) -> linalg::LinalgOp { - return cast( - rewriter - .create>( - loc, ArrayRef{resultType}, - ValueRange{input, fakeWindowDims.getResult()}, - filledInitTensor, strides, dilations, - linalg::getPrunedAttributeList(op)) - .getOperation()); - }; - linalg::LinalgOp poolingOp; - PoolingType poolingType = getPoolingType(op, result.getResultNumber()); - switch (poolingType) { - case PoolingType::k2DMin: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k3DMin: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k2DMax: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k3DMax: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k2DAdd: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::k3DAdd: { - poolingOp = createOp(static_cast(nullptr)); - break; - } - case PoolingType::kInvalid: - return rewriter.notifyMatchFailure(op, "unknown reduction operation"); - } - poolingOps.push_back(poolingOp->getResult(0)); - } - rewriter.replaceOp(op, poolingOps); - return success(); - } -}; - -} // namespace - -namespace detail { -void populateStableHloReductionToLinalgConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns, bool enablePrimitiveOps) { - if (enablePrimitiveOps) { - patterns->add(typeConverter, context); - } else { - patterns->add(typeConverter, context); - } - patterns->add(typeConverter, - context); - - // Ensure specialized patterns are higher priority than their generic - // versions. 
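(Two notes on the above. "Higher priority" is expressed through PatternBenefit: when several patterns match the same op, the driver prefers the larger benefit, so the pooling-specific ReduceWindowOpConversion is tried before the generic window lowering registered here. And the dynamic output-extent arithmetic in that pooling conversion reduces, per spatial dimension, to the following scalar computation; the helper name is hypothetical:

    #include <cstdint>

    // out = (in - window * dilation) / stride + 1, matching the
    // tensor.dim arithmetic in ReduceWindowOpConversion above.
    int64_t pooledOutputSize(int64_t in, int64_t window, int64_t stride,
                             int64_t dilation) {
      return (in - window * dilation) / stride + 1;
    }

)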
- patterns->add(typeConverter, context, - PatternBenefit(2)); -} -} // namespace detail -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp b/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp deleted file mode 100644 index 30f0bf28cf4f..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h" - -#include - -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" - -namespace mlir::iree_compiler::stablehlo { - -namespace { - -Type convertInteger(IntegerType intType) { - return IntegerType::get(intType.getContext(), - intType.getIntOrFloatBitWidth()); -} - -Type convertShapedType(ShapedType shapedType) { - if (auto intType = llvm::dyn_cast(shapedType.getElementType())) - return shapedType.clone(convertInteger(intType)); - return shapedType; -} - -Value materializeCastFromIllegal(OpBuilder &builder, Type type, - ValueRange inputs, Location loc) { - Type fromType = getElementTypeOrSelf(inputs[0].getType()); - Type toType = getElementTypeOrSelf(type); - if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || - !toType.isSignlessInteger()) - return Value(); - // Use unrealized conversion casts to do signful->signless conversions. - return builder.create(loc, type, inputs[0]) - ->getResult(0); -} - -Value materializeCastToIllegal(OpBuilder &builder, Type type, ValueRange inputs, - Location loc) { - Type fromType = getElementTypeOrSelf(inputs[0].getType()); - Type toType = getElementTypeOrSelf(type); - if (!fromType.isSignlessInteger() || - (!toType.isSignedInteger() && !toType.isUnsignedInteger())) - return Value(); - // Use unrealized conversion casts to do signless->signful conversions. 
- return builder.create(loc, type, inputs[0]) - ->getResult(0); -} - -Value scalarToTensor(OpBuilder &builder, Type type, ValueRange inputs, - Location loc) { - assert(inputs.size() == 1); - if (llvm::isa(inputs.front().getType())) { - return Value(); - } - auto tensor = - builder - .create( - loc, RankedTensorType::get({}, inputs.front().getType()), - inputs.front()) - .getResult(); - return builder.create(loc, type, tensor) - .getResult(0); -} - -} // namespace - -RemoveSignTypeConverter::RemoveSignTypeConverter() { - addConversion([](Type type) { return type; }); - - addConversion(convertInteger); - addConversion(convertShapedType); - - addArgumentMaterialization(materializeCastToIllegal); - addSourceMaterialization(materializeCastToIllegal); - addTargetMaterialization(materializeCastFromIllegal); -} - -LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() { - addArgumentMaterialization(scalarToTensor); - addSourceMaterialization(scalarToTensor); - addTargetMaterialization(scalarToTensor); -} - -} // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.h b/compiler/plugins/input/StableHLO/Conversion/TypeConversion.h deleted file mode 100644 index d7b4ce03d0dc..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/TypeConversion.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_TYPE_CONVERSION_H -#define IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_TYPE_CONVERSION_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir::iree_compiler::stablehlo { - -// Type converter to use as part of lowerings from dialects that carry signs -// in their types to those that are signless. -class RemoveSignTypeConverter : public TypeConverter { -public: - RemoveSignTypeConverter(); -}; - -// Type converter which adds additional materializations (beyond signless) -// that are needed as part of the HloToLinalg conversion patterns. -// This is the type converter used by the test pass and is the sanctioned -// way to use the underlying patterns. 
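(Aside: the sign-stripping conversion declared above amounts to rebuilding integer types from their width alone, since IntegerType defaults to signless; a one-line standalone analogue, with an illustrative name:

    #include "mlir/IR/BuiltinTypes.h"

    // si32 and ui32 both become the signless i32; other types are kept.
    mlir::Type dropSign(mlir::IntegerType t) {
      return mlir::IntegerType::get(t.getContext(), t.getWidth());
    }

The unrealized_conversion_cast materializations above carry values across the signful/signless boundary until the final IR no longer needs them.)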
-class LinalgTypeConverter : public RemoveSignTypeConverter { -public: - LinalgTypeConverter(); -}; - -} // namespace mlir::iree_compiler::stablehlo - -#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_TYPE_CONVERSION_H diff --git a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel index 774e14f75eec..c8f2420d6ae4 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel @@ -25,14 +25,7 @@ iree_lit_test_suite( "legalize_shape_computations.mlir", "stablehlo_custom_calls.mlir", "stablehlo_to_iree_input_dialects.mlir", - "stablehlo_to_linalg_convolution.mlir", - "stablehlo_to_linalg_dot_prod.mlir", "stablehlo_to_linalg_ext.mlir", - "stablehlo_to_linalg_gather.mlir", - "stablehlo_to_linalg_pointwise.mlir", - "stablehlo_to_linalg_random.mlir", - "stablehlo_to_linalg_reduce.mlir", - "stablehlo_to_linalg.mlir", "verify_compiler_input_legality.mlir", "vhlo_stablehlo_mix_invalid.mlir", ], diff --git a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt index 6d9c9b1714d5..5f7202150f84 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt +++ b/compiler/plugins/input/StableHLO/Conversion/test/CMakeLists.txt @@ -23,14 +23,7 @@ iree_lit_test_suite( "legalize_shape_computations.mlir" "stablehlo_custom_calls.mlir" "stablehlo_to_iree_input_dialects.mlir" - "stablehlo_to_linalg.mlir" - "stablehlo_to_linalg_convolution.mlir" - "stablehlo_to_linalg_dot_prod.mlir" "stablehlo_to_linalg_ext.mlir" - "stablehlo_to_linalg_gather.mlir" - "stablehlo_to_linalg_pointwise.mlir" - "stablehlo_to_linalg_random.mlir" - "stablehlo_to_linalg_reduce.mlir" "verify_compiler_input_legality.mlir" "vhlo_stablehlo_mix_invalid.mlir" TOOLS diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir deleted file mode 100644 index cd8a2b0f6de2..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg.mlir +++ /dev/null @@ -1,1618 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// RUN: iree-opt %s --iree-stablehlo-to-linalg="enable-primitive-ops=true" \ -// RUN: --split-input-file --canonicalize | \ -// RUN: FileCheck %s --check-prefix=CHECK-PRIMITIVE - -// CHECK-LABEL: func @bitcast_convert -func.func @bitcast_convert(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { - %result = "stablehlo.bitcast_convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> - func.return %result : tensor<2x2xf32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE: arith.bitcast - -// ----- - -// CHECK-LABEL: func @bitcast_convert_dynamic -func.func @bitcast_convert_dynamic(%input: tensor) -> tensor { - %result = "stablehlo.bitcast_convert"(%input) : (tensor) -> tensor - func.return %result : tensor -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32): -// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 - -// CHECK-PRIMITIVE: 
linalg.map
-// CHECK-PRIMITIVE: arith.bitcast
-
-// -----
-
-// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @bitcast_convert_expand
-func.func @bitcast_convert_expand(%input: tensor<6xi32>) -> tensor<6x4xi8> {
-  %result = "stablehlo.bitcast_convert"(%input) : (tensor<6xi32>) -> tensor<6x4xi8>
-  func.return %result : tensor<6x4xi8>
-}
-
-// CHECK: %[[C8:.*]] = arith.constant 8 : i32
-// CHECK: tensor.empty() : tensor<6x4xi8>
-// CHECK: %[[RESULT:.*]] = linalg.generic {
-// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]],
-// CHECK: iterator_types = ["parallel", "parallel"]}
-// CHECK: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]: i8):
-// CHECK: %[[IOTA:.*]] = linalg.index 1 : index
-// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32
-// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i32
-// CHECK: %[[SHIFT:.*]] = arith.shrui %[[IN]], %[[AMT]] : i32
-// CHECK: %[[TRUNC:.*]] = arith.trunci %[[SHIFT]] : i32 to i8
-// CHECK: linalg.yield %[[TRUNC]] : i8
-// CHECK: } -> tensor<6x4xi8>
-// CHECK: return %[[RESULT]] : tensor<6x4xi8>
-
-// -----
-
-// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @bitcast_convert_contract
-func.func @bitcast_convert_contract(%input: tensor<7x4xi8>) -> tensor<7xi32> {
-  %result = "stablehlo.bitcast_convert"(%input) : (tensor<7x4xi8>) -> tensor<7xi32>
-  func.return %result : tensor<7xi32>
-}
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
-// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<7xi32>
-// CHECK: linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<7xi32>) -> tensor<7xi32>
-// CHECK: %[[RESULT:.*]] = linalg.generic {
-// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]],
-// CHECK: iterator_types = ["parallel", "reduction"]}
-// CHECK: ^bb0(%[[IN:.*]]: i8, %[[OUT:.*]]: i32):
-// CHECK: %[[IOTA:.*]] = linalg.index 1 : index
-// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32
-// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i32
-// CHECK: %[[EXT:.*]] = arith.extui %[[IN]] : i8 to i32
-// CHECK: %[[SHIFT:.*]] = arith.shli %[[EXT]], %[[AMT]] : i32
-// CHECK: %[[OR:.*]] = arith.ori %[[SHIFT]], %[[OUT]] : i32
-// CHECK: linalg.yield %[[OR]] : i32
-// CHECK: } -> tensor<7xi32>
-// CHECK: return %[[RESULT]] : tensor<7xi32>
-
-// -----
-
-// CHECK-LABEL: func @concatenate(
-// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_15:.*]] = tensor.dim %[[VAL_2]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_15]] : index
-// CHECK: %[[VAL_23:.*]] = tensor.empty(%[[VAL_5]], %[[VAL_17]]) : tensor<?x?xi32>
-// CHECK: %[[VAL_24:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_23]] : tensor<?x?xi32>) {
-// CHECK: ^bb0(%[[VAL_25:.*]]: i32):
-// CHECK: %[[VAL_26:.*]] = linalg.index 0 : index
-// CHECK: %[[VAL_28:.*]] = linalg.index 1 : index
-// CHECK: %[[VAL_30:.*]] = tensor.dim %[[VAL_0]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_30]] : index
-// CHECK: %[[VAL_33:.*]] = scf.if %[[VAL_32]] -> (i32) {
-// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_0]][%[[VAL_26]], %[[VAL_28]]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[VAL_35]] : i32
-// CHECK: } else {
-// CHECK: %[[VAL_37:.*]] = tensor.dim %[[VAL_1]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_30]], %[[VAL_37]] : index
-// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_38]] : index
-// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (i32) {
-// CHECK: %[[VAL_41:.*]] = arith.subi %[[VAL_28]], %[[VAL_30]] : index
-// CHECK: %[[VAL_42:.*]] = tensor.extract %[[VAL_1]][%[[VAL_26]], %[[VAL_41]]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[VAL_42]] : i32
-// CHECK: } else {
-// CHECK: %[[VAL_43:.*]] = arith.subi %[[VAL_28]], %[[VAL_38]] : index
-// CHECK: %[[VAL_44:.*]] = tensor.extract %[[VAL_2]][%[[VAL_26]], %[[VAL_43]]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[VAL_44]] : i32
-// CHECK: }
-// CHECK: scf.yield %[[VAL_45:.*]] : i32
-// CHECK: }
-// CHECK: linalg.yield %[[VAL_46:.*]] : i32
-// CHECK: } -> tensor<?x?xi32>
-// CHECK: return %[[VAL_47:.*]] : tensor<?x?xi32>
-// CHECK: }
-func.func @concatenate(%a: tensor<?x?xi32>, %b: tensor<?x?xi32>, %c: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %concat = "stablehlo.concatenate"(%a, %b, %c) {
-    dimension = 1
-  } : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
-  func.return %concat : tensor<?x?xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @concatenate_unsigned(
-// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]]
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : tensor<?x?xui32> to tensor<?x?xi32>
-// CHECK-DAG: %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : tensor<?x?xui32> to tensor<?x?xi32>
-// CHECK-DAG: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : tensor<?x?xui32> to tensor<?x?xi32>
-// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_3]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_3]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_4]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_18:.*]] = tensor.dim %[[VAL_5]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_10]], %[[VAL_14]] : index
-// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
-// CHECK: %[[VAL_26:.*]] = tensor.empty(%[[VAL_8]], %[[VAL_20]]) : tensor<?x?xi32>
-// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_26]] : tensor<?x?xi32>) {
-// CHECK: ^bb0(%[[VAL_28:.*]]: i32):
-// CHECK: %[[VAL_29:.*]] = linalg.index 0 : index
-// CHECK: %[[VAL_30:.*]] = linalg.index 1 : index
-// CHECK: %[[VAL_33:.*]] = tensor.dim %[[VAL_3]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_33]] : index
-// CHECK: %[[VAL_36:.*]] = scf.if %[[VAL_35]] -> (i32) {
-// CHECK: %[[VAL_38:.*]] = tensor.extract %[[VAL_3]][%[[VAL_29]], %[[VAL_30]]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[VAL_38]] : i32
-// CHECK: } else {
-// CHECK: %[[VAL_40:.*]] = tensor.dim %[[VAL_4]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_33]], %[[VAL_40]] : index
-// CHECK: %[[VAL_42:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_41]] : index
-// CHECK: %[[VAL_43:.*]] = scf.if %[[VAL_42]] -> (i32) {
-// CHECK: %[[VAL_44:.*]] = arith.subi %[[VAL_30]], %[[VAL_33]] : index
-// CHECK: %[[VAL_45:.*]] = tensor.extract %[[VAL_4]][%[[VAL_29]], %[[VAL_44]]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[VAL_45]] : i32
-// CHECK: } else {
-// CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_30]], %[[VAL_41]] : index
-// CHECK: %[[VAL_47:.*]] = tensor.extract %[[VAL_5]][%[[VAL_29]], %[[VAL_46]]] : tensor<?x?xi32>
-// CHECK: scf.yield %[[VAL_47]] : i32
-// CHECK: }
-// CHECK: scf.yield %[[VAL_48:.*]] : i32
-// CHECK: }
-// CHECK: linalg.yield %[[VAL_49:.*]] : i32
-// CHECK: } -> tensor<?x?xi32>
-// CHECK: %[[VAL_50:.*]] = builtin.unrealized_conversion_cast %[[VAL_51:.*]] : tensor<?x?xi32> to tensor<?x?xui32>
-// CHECK: return %[[VAL_50]] : tensor<?x?xui32>
-// CHECK: }
-func.func @concatenate_unsigned(%a: tensor<?x?xui32>, %b: tensor<?x?xui32>, %c: tensor<?x?xui32>) -> tensor<?x?xui32> {
-  %concat = "stablehlo.concatenate"(%a, %b, %c) {
-    dimension = 1
-  } : (tensor<?x?xui32>, tensor<?x?xui32>, tensor<?x?xui32>) -> tensor<?x?xui32>
-  func.return %concat : tensor<?x?xui32>
-}
-
-// -----
-
-// CHECK-LABEL: func @constant
-// CHECK: %[[CONSTANT:.*]] = arith.constant dense<10> : tensor<i32>
-func.func @constant() -> tensor<i32> {
-  %result = "stablehlo.constant"() {
-    value = dense<10> : tensor<i32>
-  } : () -> (tensor<i32>)
-  func.return %result : tensor<i32>
-}
-
-// -----
-
-// CHECK-LABEL: func @elided_constant
-// CHECK: %[[CONSTANT:.*]] = arith.constant dense_resource<__elided__> : tensor<1024xf32>
-func.func @elided_constant() -> tensor<1024xf32> {
-  %result = "stablehlo.constant"() {
-    value = dense_resource<__elided__> : tensor<1024xf32>
-  } : () -> (tensor<1024xf32>)
-  func.return %result : tensor<1024xf32>
-}
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: func @einsum_basic
-// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x5x6xf32>)
-func.func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
-  %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "ijk,ikm->ijm", someattr}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
-  func.return %0 : tensor<3x4x6xf32>
-}
-// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x4x6xf32>
-// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5xf32>, tensor<3x5x6xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<3x4x6xf32>)
-// CHECK-SAME: {someattr}
-// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32):
-// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32
-// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32
-// CHECK: linalg.yield %[[RES]]
-
-// -----
-
-// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @einsum_pointwisemul
-func.func @einsum_pointwisemul(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
-  %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abc,abc->abc"} : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
-  func.return %0 : tensor<3x4x5xf32>
-}
-// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x4x5xf32>)
-// CHECK: tensor.empty() : tensor<3x4x5xf32>
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5xf32>, tensor<3x4x5xf32>)
-// CHECK-SAME: outs(%[[DST:.+]] : tensor<3x4x5xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[RES:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK: func @einsum_matmul -func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tensor<7x5xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "ae,ed->ad"}: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> - func.return %0 : tensor<7x5xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<7x9xf32>, %[[RHS:.*]]: tensor<9x5xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<7x5xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<7x9xf32>, tensor<9x5xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<7x5xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d5)> -// CHECK: func @einsum_broadcast4 -func.func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abcdh,hg->abcdg"}: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> - func.return %0 : tensor<3x4x5x6x8xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5x6x7xf32>, %[[RHS:.*]]: tensor<7x8xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x4x5x6x8xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<3x4x5x6x8xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: func @einsum_ellipsis -func.func @einsum_ellipsis(%arg0: tensor<1x512x128xf32>, %arg1: tensor<128x256xf32>) -> tensor<1x512x256xf32> { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "...x,xy->...y"} : (tensor<1x512x128xf32>, tensor<128x256xf32>) -> tensor<1x512x256xf32> - func.return %0 : 
tensor<1x512x256xf32> -} -// CHECK-SAME: (%[[LHS:.*]]: tensor<1x512x128xf32>, %[[RHS:.*]]: tensor<128x256xf32>) -// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<1x512x256xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<1x512x128xf32>, tensor<128x256xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x512x256xf32>) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: func @einsum_dynamic_size_broadcast_dot -func.func @einsum_dynamic_size_broadcast_dot(%arg0: tensor, %arg1: tensor<4x?xf32>) -> tensor { - %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "abc,cd->abd"} : (tensor, tensor<4x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-SAME: (%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor<4x?xf32>) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[LHS]], %[[C0]] : tensor -// CHECK: %[[DIM1:.+]] = tensor.dim %[[LHS]], %[[C1]] : tensor -// CHECK: %[[DIM2:.+]] = tensor.dim %[[RHS]], %[[C1:.+]] : tensor<4x?xf32> -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]]) : tensor -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"] -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor, tensor<4x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32 -// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[RES]] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK: func @broadcast_in_dim -func.func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> - func.return %0 : tensor<7x10x6x4x5xf32> -} -// CHECK: tensor.empty() : tensor<7x10x6x4x5xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim -// CHECK-PRIMITIVE: tensor.collapse_shape -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE: permutation = [1, 0] -// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6x4x5xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [1, 2, 3] - -// ----- 
- -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK: func @broadcast_in_dim_ui32 -func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32> - func.return %0 : tensor<7x10x6x4x5xui32> -} -// CHECK: builtin.unrealized_conversion_cast %{{.*}} : tensor<5x7x1xui32> to tensor<5x7x1xi32> -// CHECK: tensor.empty() : tensor<7x10x6x4x5xi32> -// CHECK: %[[RES:.*]] = linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: i32, %{{.*}}: i32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : i32 -// CHECK: builtin.unrealized_conversion_cast %[[RES]] : tensor<7x10x6x4x5xi32> to tensor<7x10x6x4x5xui32> - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_ui32 -// CHECK-PRIMITIVE: tensor.collapse_shape -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE: permutation = [1, 0] -// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6x4x5xi32> -// CHECK-PRIMITIVE: %[[RES:.*]] = linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [1, 2, 3] -// CHECK-PRIMITIVE: builtin.unrealized_conversion_cast %[[RES]] : tensor<7x10x6x4x5xi32> to tensor<7x10x6x4x5xui32> - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @broadcast_in_dim_with_one_to_one -func.func @broadcast_in_dim_with_one_to_one( - %operand: tensor<1xf32>) -> tensor<1x5xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<1xf32>) -> tensor<1x5xf32> - func.return %0 : tensor<1x5xf32> -} -// CHECK: tensor.empty() : tensor<1x5xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_with_one_to_one -// CHECK-PRIMITIVE-NOT: tensor.collapse_shape -// CHECK-PRIMITIVE-NOT: linalg.transpose -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [1] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d1)> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @broadcast_in_dim_with_transpose -func.func @broadcast_in_dim_with_transpose( - %operand: tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> - func.return %0 : tensor<3x4x2x5xf32> -} -// CHECK: tensor.empty() : tensor<3x4x2x5xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_with_transpose -// CHECK-PRIMITIVE: tensor.empty() : tensor<3x4x2xf32> -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE: permutation = [1, 2, 0] -// CHECK-PRIMITIVE: tensor.empty() : tensor<3x4x2x5xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [3] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, 
d2)> -// CHECK: func @broadcast_in_dim_scalar -func.func @broadcast_in_dim_scalar(%operand: tensor) -> tensor<7x10x6xf32> { - %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = array} - : (tensor) -> tensor<7x10x6xf32> - func.return %0 : tensor<7x10x6xf32> -} -// CHECK: tensor.empty() : tensor<7x10x6xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_scalar -// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [0, 1, 2] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @broadcast_scalar -func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { - %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array} : (tensor) -> tensor<4x2x1xf32> - func.return %0: tensor<4x2x1xf32> -} -// CHECK: tensor.empty() : tensor<4x2x1xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast_scalar -// CHECK-PRIMITIVE: tensor.empty() : tensor<4x2x1xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE-SAME: ins( -// CHECK-PRIMITIVE-SAME: outs( -// CHECK-PRIMITIVE-SAME: dimensions = [0, 1, 2] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> -// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -// CHECK: func @broadcast -func.func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { - %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> - func.return %0: tensor<4x2x1x4x?x16xf32> -} -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> -// CHECK: %{{.*}} = tensor.empty(%[[DIM]]) : tensor<4x2x1x4x?x16xf32> -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): -// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 - -// CHECK-PRIMITIVE-LABEL: func @broadcast -// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-PRIMITIVE: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> -// CHECK-PRIMITIVE: %{{.*}} = tensor.empty(%[[DIM]]) : tensor<4x2x1x4x?x16xf32> -// CHECK-PRIMITIVE: linalg.broadcast -// CHECK-PRIMITIVE: dimensions = [0, 1, 2] - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_f32 -func.func @iota_f32() -> tensor<7x10xf32> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64, someattr} : () -> (tensor<7x10xf32>) - func.return %result : tensor<7x10xf32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-SAME: {someattr} -// CHECK-NEXT: ^bb0(%{{.*}}: f32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i32 to f32 -// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 - -// CHECK-PRIMITIVE-LABEL: func 
@iota_f32 -// CHECK-PRIMITIVE: %[[EMPTY:.*]] = tensor.empty() -// CHECK-PRIMITIVE: linalg.map outs(%[[EMPTY]] : tensor<7x10xf32> -// CHECK-PRIMITIVE-SAME: {someattr} -// CHECK-PRIMITIVE: %[[INDEX:.*]] = linalg.index 1 -// CHECK-PRIMITIVE-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i64 -// CHECK-PRIMITIVE-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i64 to f32 -// CHECK-PRIMITIVE-NEXT: linalg.yield %[[FLOAT_CAST]] - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_i32 -func.func @iota_i32() -> tensor<7x10xi32> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xi32>) - func.return %result : tensor<7x10xi32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: i32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32 - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_ui32 -func.func @iota_ui32() -> tensor<7x10xui32> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xui32>) - func.return %result : tensor<7x10xui32> -} -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: i32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32 -// CHECK: builtin.unrealized_conversion_cast %{{.*}} : tensor<7x10xi32> to tensor<7x10xui32> - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @iota_complexf32 -func.func @iota_complexf32() -> tensor<7x10xcomplex> { - %result = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xcomplex>) - func.return %result : tensor<7x10xcomplex> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: tensor.empty -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: complex): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i32 to f32 -// CHECK-NEXT: %[[COMPLEX_CAST:.*]] = complex.create %[[FLOAT_CAST]], %[[ZERO]] : complex -// CHECK-NEXT: linalg.yield %[[COMPLEX_CAST]] : complex - -// ----- - -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @dynamic_iota_f32 -// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> -func.func @dynamic_iota_f32(%shape: tensor<3xi32>) -> tensor { - %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor) - func.return %result : tensor -} -// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] -// CHECK: %[[I1:.*]] = arith.index_cast %[[V1]] : i32 to index -// CHECK: %[[V2:.*]] = tensor.extract %[[SHAPE]][%c1] -// CHECK: %[[I2:.*]] = arith.index_cast %[[V2]] : i32 to index -// CHECK: tensor.empty(%[[I1]], %[[I2]]) : tensor -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] -// CHECK-NEXT: ^bb0(%{{.*}}: f32): -// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1 -// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32 -// CHECK-NEXT: %[[FLOAT_CAST:.*]] = arith.sitofp %[[INT_CAST]] : i32 to f32 
-// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
-
-// -----
-
-// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @dynamic_iota_ui32
-// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
-func.func @dynamic_iota_ui32(%shape: tensor<3xi32>) -> tensor {
-  %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor)
-  func.return %result : tensor
-}
-// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
-// CHECK: %[[I1:.*]] = arith.index_cast %[[V1]] : i32 to index
-// CHECK: %[[V2:.*]] = tensor.extract %[[SHAPE]][%c1]
-// CHECK: %[[I2:.*]] = arith.index_cast %[[V2]] : i32 to index
-// CHECK: tensor.empty(%[[I1]], %[[I2]]) : tensor
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
-// CHECK-NEXT: ^bb0(%{{.*}}: i32):
-// CHECK-NEXT: %[[INDEX:.*]] = linalg.index 1
-// CHECK-NEXT: %[[INT_CAST:.*]] = arith.index_cast %[[INDEX]] : index to i32
-// CHECK-NEXT: linalg.yield %[[INT_CAST]] : i32
-// CHECK: builtin.unrealized_conversion_cast %{{.*}} : tensor to tensor
-
-// -----
-
-// CHECK-LABEL: @map_mixed
-// CHECK-PRIMITIVE-LABEL: @map_mixed
-func.func @map_mixed(%arg0: tensor<?xf32>,
-                     %arg1: tensor<4xf32>) -> tensor<?xf32> {
-  %0 = "stablehlo.map"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {dimensions = array<i64: 0>}
-  : (tensor<?xf32>, tensor<4xf32>) -> tensor<?xf32>
-  func.return %0 : tensor<?xf32>
-}
-
-// CHECK: linalg.generic
-// CHECK: %[[ADD:.+]] = arith.addf
-// CHECK: linalg.yield %[[ADD]] : f32
-
-// CHECK-PRIMITIVE: linalg.map { arith.addf }
-
-// -----
-
-// CHECK-LABEL: @map_one_arg
-// CHECK-PRIMITIVE-LABEL: @map_one_arg
-func.func @map_one_arg(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = "stablehlo.map"(%arg0) ({
-  ^bb0(%arg2: tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg2 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {dimensions = array<i64: 0>}
-  : (tensor<?xf32>) -> tensor<?xf32>
-  func.return %0 : tensor<?xf32>
-}
-
-// CHECK: linalg.generic
-// CHECK: %[[ADD:.+]] = arith.addf
-// CHECK: linalg.yield %[[ADD]] : f32
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: %[[ADD:.+]] = arith.addf
-// CHECK-PRIMITIVE: linalg.yield %[[ADD]] : f32
-
-// -----
-
-// CHECK-LABEL: @map_compare
-// CHECK-SAME: %[[ARG0:.*]]: tensor<?xcomplex<f32>>,
-// CHECK-SAME: %[[ARG1:.*]]: tensor<?xcomplex<f32>>)
-// CHECK-PRIMITIVE-LABEL: @map_compare
-// CHECK-PRIMITIVE-SAME: %[[ARG0:.*]]: tensor<?xcomplex<f32>>,
-// CHECK-PRIMITIVE-SAME: %[[ARG1:.*]]: tensor<?xcomplex<f32>>)
-func.func @map_compare(%arg0: tensor<?xcomplex<f32>>,
-                       %arg1: tensor<?xcomplex<f32>>) -> tensor<?xi1> {
-  %0 = "stablehlo.map"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<complex<f32>>, %arg3: tensor<complex<f32>>):
-    %1 = stablehlo.real %arg2 : (tensor<complex<f32>>) -> tensor<f32>
-    %2 = stablehlo.real %arg3 : (tensor<complex<f32>>) -> tensor<f32>
-    %3 = "stablehlo.compare"(%1, %2)
-      {comparison_direction = #stablehlo<comparison_direction EQ>}
-      : (tensor<f32>, tensor<f32>) -> tensor<i1>
-    "stablehlo.return"(%3) : (tensor<i1>) -> ()
-  }) {dimensions = array<i64: 0>}
-  : (tensor<?xcomplex<f32>>, tensor<?xcomplex<f32>>) -> tensor<?xi1>
-  func.return %0 : tensor<?xi1>
-}
-
-// CHECK: %[[INIT:.+]] = tensor.empty
-// CHECK: %[[MAP:.+]] = linalg.generic
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
-// CHECK-SAME: outs(%[[INIT]] : tensor<?xi1>) {
-// CHECK: ^bb0(%[[A:.+]]: complex<f32>, %[[B:.+]]: complex<f32>, %{{.+}}: i1):
-// CHECK: %[[RE1:.+]] = complex.re %[[A]] : complex<f32>
-// CHECK: %[[RE2:.+]] = complex.re %[[B]] : complex<f32>
-// CHECK: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32
-// CHECK: linalg.yield %[[CMP]] : i1
-// CHECK: }
-// CHECK: return %[[MAP]] : tensor<?xi1>
-
-// CHECK-PRIMITIVE: %[[INIT:.+]] = tensor.empty
-// CHECK-PRIMITIVE: %[[MAP:.+]] = linalg.map -// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) -// CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex, %[[B:.+]]: complex) { -// CHECK-PRIMITIVE: %[[RE1:.+]] = complex.re %[[A]] : complex -// CHECK-PRIMITIVE: %[[RE2:.+]] = complex.re %[[B]] : complex -// CHECK-PRIMITIVE: %[[CMP:.+]] = arith.cmpf oeq, %[[RE1]], %[[RE2]] : f32 -// CHECK-PRIMITIVE: linalg.yield %[[CMP]] : i1 -// CHECK-PRIMITIVE: } -// CHECK-PRIMITIVE: return %[[MAP]] : tensor - -// ----- - -func.func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { - %0 = arith.constant dense<0.0> : tensor - %1 = "stablehlo.pad"(%arg0, %0) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xf32>, tensor) -> tensor<18x12xf32> - func.return %1 : tensor<18x12xf32> -} -// CHECK-LABEL: func @pad_cst -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: tensor.pad %[[ARG0]] low[4, 5] high[2, 3] -// CHECK: tensor.yield %[[CST]] : f32 -// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32> - -// ----- - -func.func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf32> { - %0 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xf32>, tensor) -> tensor<18x12xf32> - func.return %0 : tensor<18x12xf32> -} -// CHECK-LABEL: func @pad_tensor -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: tensor.pad %[[ARG0]] low[4, 5] high[2, 3] -// CHECK: tensor.yield %[[PAD]] : f32 -// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32> - -// ----- - -func.func @pad_interior(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor<29x15xui32> { - %0 = arith.constant dense<0> : tensor - %1 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xui32>, tensor) -> tensor<29x15xui32> - func.return %1 : tensor<29x15xui32> -} -// CHECK-LABEL: func @pad_interior -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[CAST0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<12x4xui32> to tensor<12x4xi32> -// CHECK-DAG: %[[CAST1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : tensor to tensor -// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[CAST1]][] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<29x15xi32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PAD]] : i32) outs(%[[INIT]] : tensor<29x15xi32>) -> tensor<29x15xi32> -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[CAST0]] into %[[FILL]][4, 5] [12, 4] [2, 2] : tensor<12x4xi32> into tensor<29x15xi32> - -// ----- - -func.func @pad_interior_negative(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor<25x9xui32> { - %0 = arith.constant dense<0> : tensor - %1 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = array, - edge_padding_low = array, - interior_padding = array - } : (tensor<12x4xui32>, tensor) -> tensor<25x9xui32> - func.return %1 : tensor<25x9xui32> -} -// CHECK-LABEL: func @pad_interior_negative -// CHECK: %[[PAD:.*]] = tensor.insert_slice %{{.+}} into %{{.+}}[4, 0] [12, 4] [2, 2] : tensor<12x4xi32> into tensor<29x10xi32> -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 1] [25, 9] [1, 1] : tensor<29x10xi32> to tensor<25x9xi32> - -// ----- - -// 
CHECK-LABEL: func @real_dynamic_slice
-// CHECK-SAME: (%[[OPERAND:.*]]: tensor<256x?xf32>, %[[START_INDICES:.*]]: tensor<2xindex>, %[[LIMIT_INDICES:.*]]: tensor<2xindex>, %[[STRIDES:.*]]: tensor<2xindex>)
-func.func @real_dynamic_slice(%input: tensor<256x?xf32>, %start_indices: tensor<2xindex>, %limit_indices: tensor<2xindex>, %strides: tensor<2xindex>) -> tensor<256x?xf32> {
-  %0 = "stablehlo.real_dynamic_slice"(%input, %start_indices, %limit_indices, %strides) : (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32>
-  func.return %0 : tensor<256x?xf32>
-}
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1
-
-// Fetch the start index, limit index, and stride.
-// CHECK-DAG: %[[START0:.*]] = tensor.extract %[[START_INDICES]][%[[C0]]]
-// CHECK-DAG: %[[STRIDE0:.*]] = tensor.extract %[[STRIDES]][%[[C0]]]
-
-// Clamp the starting index: 0 <= start <= ub.
-// CHECK-DAG: %[[MAX0:.*]] = arith.maxsi %[[START0]], %[[C0]] : index
-// CHECK-DAG: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C0]] : index
-
-// CHECK-DAG: %[[START1:.*]] = tensor.extract %[[START_INDICES]][%[[C1]]]
-// CHECK-DAG: %[[LIMIT1:.*]] = tensor.extract %[[LIMIT_INDICES]][%[[C1]]]
-// CHECK-DAG: %[[STRIDE1:.*]] = tensor.extract %[[STRIDES]][%[[C1]]]
-
-// 2.2. Since dimension 1 of the result is unknown, compute the result size at
-// dimension 1 as size[1] = ceil((limit - start) / stride).
-// CHECK-DAG: %[[DELTA1:.*]] = arith.subi %[[LIMIT1]], %[[START1]] : index
-// CHECK-DAG: %[[SIZE1:.*]] = arith.ceildivui %[[DELTA1]], %[[STRIDE1]] : index
-
-// 2.3. Compute the upper bound for the starting index: ub = operand_dim[1] - size[1],
-// where size[1] is computed in step 2.2.
-// CHECK-DAG: %[[OPERAND_DIM1:.*]] = tensor.dim %[[OPERAND]], %[[C1]] : tensor<256x?xf32>
-// CHECK-DAG: %[[UB:.*]] = arith.subi %[[OPERAND_DIM1]], %[[SIZE1]] : index
-
-// 2.4. Clamp the starting index: 0 <= start <= ub, where the upper bound (ub)
-// is computed in step 2.3.
-// CHECK-DAG: %[[MAX1:.*]] = arith.maxsi %[[START1]], %[[C0]] : index
-// CHECK-DAG: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[UB]] : index
-
-// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[OPERAND]][%[[MIN0]], %[[MIN1]]] [256, %[[SIZE1]]] [%[[STRIDE0]], %[[STRIDE1]]] : tensor<256x?xf32> to tensor<256x?xf32>
-// CHECK: return %[[SLICE]] : tensor<256x?xf32>
-
-// -----
-
-// Verify that legalization of real_dynamic_slice with integer dims works and
-// passes verification.
-// CHECK-LABEL: real_dynamic_slice_with_int
-func.func @real_dynamic_slice_with_int(%arg0: tensor<10xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
-  %0 = stablehlo.constant dense<0> : tensor<1xi32>
-  %1 = stablehlo.constant dense<1> : tensor<1xi32>
-  %2 = stablehlo.constant dense<0> : tensor<i32>
-  %4 = "stablehlo.real_dynamic_slice"(%arg0, %0, %arg1, %1) : (tensor<10xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
-  func.return %4 : tensor<?xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @reshape_0D_1D
-func.func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
-  func.return %0 : tensor<1xi32>
-}
-// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1] : tensor<i32> into tensor<1xi32>
-
-// -----
-
-// CHECK-LABEL: func @reshape_0D_1D_unsigned
-// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
-func.func @reshape_0D_1D_unsigned(%arg0: tensor<ui32>) -> tensor<1xui32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<ui32>) -> tensor<1xui32>
-  func.return %0 : tensor<1xui32>
-}
-// CHECK: %[[ARG_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<ui32> to tensor<i32>
-// CHECK: %[[RET_SIGNLESS:.*]] = tensor.expand_shape %[[ARG_SIGNLESS]] [] output_shape [1] : tensor<i32> into tensor<1xi32>
-// CHECK: %[[RET_UNSIGNED:.*]] = builtin.unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32>
-// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32>
-
-// -----
-
-// CHECK-LABEL: func @reshape_1D_0D
-func.func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
-  func.return %0 : tensor<i32>
-}
-// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
-
-// -----
-
-// CHECK-LABEL: func @reshape_1D_0D_unsigned
-// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
-func.func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor<ui32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor<ui32>
-  func.return %0 : tensor<ui32>
-}
-// CHECK: %[[ARG_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32>
-// CHECK: %[[RET_SIGNLESS:.*]] = tensor.collapse_shape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32>
-// CHECK: %[[RET_UNSIGNED:.*]] = builtin.unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<i32> to tensor<ui32>
-// CHECK: return %[[RET_UNSIGNED]] : tensor<ui32>
-
-// -----
-
-// CHECK-LABEL: func @reshape_3D_2D
-func.func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
-  func.return %0 : tensor<12x42xi32>
-}
-// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2]]
-
-// -----
-
-// CHECK-LABEL: func @reshape_4D_2D
-func.func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
-  func.return %0 : tensor<12x42xi32>
-}
-// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
-
-// -----
-
-// CHECK-LABEL: func @reshape_2D_4D
-func.func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
-  func.return %0 : tensor<12x1x42x1xi32>
-}
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
-
-// -----
-
-// CHECK-LABEL: func @reshape_3D_4D
-func.func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
-  %0 = "stablehlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
-  func.return %0 : tensor<1x784x1x1xf32>
-}
-// CHECK:
tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2]] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3]] - -// ----- - -// CHECK-LABEL: func @reshape_4D_3D -func.func @reshape_4D_3D(%arg0: tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1x8x10x3xf32>) -> tensor<1x240x1xf32> - func.return %0 : tensor<1x240x1xf32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2] - -// ----- - -// CHECK-LABEL: func @reshape1_4D_4D -func.func @reshape1_4D_4D(%arg0: tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<4x512x1x1xi32>) -> tensor<1x4x1x512xi32> - func.return %0 : tensor<1x4x1x512xi32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3] - -// ----- - -// CHECK-LABEL: func @reshape2_4D_4D -func.func @reshape2_4D_4D(%arg0: tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<4x1x1x1024xi32>) -> tensor<4x1024x1x1xi32> - func.return %0 : tensor<4x1024x1x1xi32> -} -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3] - -// ----- - -// CHECK-LABEL: func @reshape_dynamic_in -func.func @reshape_dynamic_in(%arg0: tensor) -> tensor<2x4x5xf32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<2x4x5xf32> - func.return %0 : tensor<2x4x5xf32> -} -// CHECK: %[[FLATTEN:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] : tensor into tensor -// CHECK: %[[CAST:.*]] = tensor.cast %[[FLATTEN]] : tensor to tensor<40xf32> -// CHECK: tensor.expand_shape %[[CAST]] {{\[}}[0, 1, 2]] output_shape [2, 4, 5] : tensor<40xf32> into tensor<2x4x5xf32> - -// ----- - -// CHECK-LABEL: func @reshape_1D_2D_dynamic -func.func @reshape_1D_2D_dynamic(%arg0: tensor) -> tensor<1x3xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<1x3xi32> - func.return %0 : tensor<1x3xi32> -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor to tensor<3xi32> -// CHECK: tensor.expand_shape %[[CAST]] {{\[}}[0, 1]] output_shape [1, 3] : tensor<3xi32> into tensor<1x3xi32> - -// ----- - -// CHECK-LABEL: func @reshape_2D_1D_dynamic -func.func @reshape_2D_1D_dynamic(%arg0: tensor) -> tensor<3xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} -// CHECK: %[[FLATTEN:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0, 1]] : tensor into tensor -// CHECK: %[[CAST:.*]] = tensor.cast %[[FLATTEN]] : tensor to tensor<3xi32> -// CHECK: return %[[CAST:.*]] : tensor<3xi32> - -// ----- -// CHECK-LABEL: func @reshape_2D_1D_semidynamic -func.func @reshape_2D_1D_semidynamic(%arg0: tensor<1x?xi32>) -> tensor<1xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<1x?xi32>) -> tensor<1xi32> - func.return %0 : tensor<1xi32> -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xi32> to tensor<1x1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}[0, 1]] : tensor<1x1xi32> into tensor<1xi32> -// CHECK: return %[[COLLAPSE:.*]] : tensor<1xi32> - -// ----- - -// CHECK-LABEL: func @reshape_1D_0D_dynamic -func.func @reshape_1D_0D_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor to tensor<1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}] : tensor<1xi32> into tensor -// CHECK: return %[[COLLAPSE:.*]] : tensor - -// ----- 
- -// CHECK-LABEL: func @reshape_2D_0D_dynamic -func.func @reshape_2D_0D_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.reshape"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor to tensor<1x1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}] : tensor<1x1xi32> into tensor -// CHECK: return %[[COLLAPSE:.*]] : tensor - -// ----- - -// CHECK-LABEL: func @reshape_3D_1D_semidynamic -func.func @reshape_3D_1D_semidynamic(%arg0: tensor<16x1x?xi32>) -> tensor<16xi32> { - %0 = "stablehlo.reshape"(%arg0) : (tensor<16x1x?xi32>) -> tensor<16xi32> - func.return %0 : tensor<16xi32> -} -// CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<16x1x?xi32> to tensor<16x1x1xi32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[CAST]] {{\[}}[0, 1, 2]] : tensor<16x1x1xi32> into tensor<16xi32> -// CHECK: return %[[COLLAPSE:.*]] : tensor<16xi32> - -// ----- - -// CHECK-LABEL: func @reshape_empty -func.func @reshape_empty(%arg0: tensor<7x0xf64>) -> tensor<0x42x101xf64> { - %0 = stablehlo.reshape %arg0 : (tensor<7x0xf64>) -> tensor<0x42x101xf64> - return %0 : tensor<0x42x101xf64> -} - -// CHECK: %[[INIT:.*]] = tensor.empty -// CHECK: return %[[INIT]] - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func @reverse -func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { - %result = "stablehlo.reverse"(%input) { - dimensions = array, someattr - } : (tensor<2x3xf32>) -> tensor<2x3xf32> - func.return %result : tensor<2x3xf32> -} -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-SAME: {someattr} - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -// CHECK: func.func @select_and_scatter -func.func @select_and_scatter(%arg0 : tensor<2x8x8x1xf32>, %arg1 : tensor<2x4x4x1xf32>, %arg2 : tensor) -> tensor<2x8x8x1xf32> { - %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - %9 = stablehlo.compare GE, %arg3, %arg4, FLOAT : (tensor, tensor) -> tensor - stablehlo.return %9 : tensor - }, { - ^bb0(%arg3: tensor, %arg4: tensor): - %9 = stablehlo.add %arg3, %arg4 : tensor - stablehlo.return %9 : tensor - }) { - padding = dense<0> : tensor<4x2xi64>, - window_dimensions = array, - window_strides = array - } : (tensor<2x8x8x1xf32>, tensor<2x4x4x1xf32>, tensor) -> tensor<2x8x8x1xf32> - - return %0 : tensor<2x8x8x1xf32> -} -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[CN1:.+]] = arith.constant -1 : i32 -// CHECK: %[[EMPTY_VAL:.+]] = tensor.empty() : tensor<2x4x4x1xf32> -// CHECK: %[[EMPTY_IDX:.+]] = tensor.empty() : tensor<2x4x4x1xi32> -// CHECK: %[[FILL_IDX:.+]] = linalg.fill ins(%[[CN1]] : i32) outs(%[[EMPTY_IDX]] : tensor<2x4x4x1xi32>) -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<2x2xf32> -// CHECK: %[[SELECT_GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#map, #map1, #map2, #map2] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", 
"parallel", "reduction", "reduction"]} -// CHECK-SAME: ins(%arg0, %[[WINDOW]] : tensor<2x8x8x1xf32>, tensor<2x2xf32>) -// CHECK-SAME: outs(%[[EMPTY_VAL]], %[[FILL_IDX]] : tensor<2x4x4x1xf32>, tensor<2x4x4x1xi32>) -// CHECK: ^bb0(%[[VAL:.+]]: f32, %[[IDX:.+]]: f32, %[[OLD_VAL:.+]]: f32, %[[OLD_IDX:.+]]: i32): -// CHECK: %[[CMPF:.+]] = arith.cmpf oge, %[[VAL]], %[[OLD_VAL]] : f32 -// CHECK: %[[CMPI:.+]] = arith.cmpi eq, %[[OLD_IDX]], %[[CN1]] : i32 -// CHECK: %[[PRED:.+]] = arith.ori %[[CMPF]], %[[CMPI]] : i1 -// CHECK: %[[IDX4:.+]] = linalg.index 4 : index -// CHECK: %[[IDX5:.+]] = linalg.index 5 : index -// CHECK: %[[MUL:.+]] = arith.muli %[[IDX4]], %[[C2]] : index -// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[IDX5]] : index -// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]] : index to i32 -// CHECK: %[[SEL_IDX:.+]] = arith.select %[[PRED]], %[[CAST]], %[[OLD_IDX]] : i32 -// CHECK: %[[SEL_VAL:.+]] = arith.select %[[PRED]], %[[VAL]], %[[OLD_VAL]] : f32 -// CHECK: linalg.yield %[[SEL_VAL]], %[[SEL_IDX]] : f32, i32 - -// CHECK: %[[SCATTER_EMPTY:.+]] = tensor.empty() : tensor<2x4x2x4x2x1xf32> -// CHECK: %[[INIT:.+]] = tensor.extract %arg2[] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT]] : f32) outs(%[[SCATTER_EMPTY]] : tensor<2x4x2x4x2x1xf32>) -> tensor<2x4x2x4x2x1xf32> -// CHECK: %[[SCATTER:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#map3, #map3, #map4] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} -// CHECK-SAME: ins(%[[SELECT_GENERIC]]#1, %arg1 : tensor<2x4x4x1xi32>, tensor<2x4x4x1xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x4x2x4x2x1xf32>) -// CHECK: ^bb0(%[[IDX:.+]]: i32, %[[UPDATE:.+]]: f32, %[[OLD:.+]]: f32): -// CHECK: %[[NEW:.+]] = arith.addf %[[UPDATE]], %[[OLD]] : f32 -// CHECK: %[[IDX2:.+]] = linalg.index 2 : index -// CHECK: %[[IDX4:.+]] = linalg.index 4 : index -// CHECK: %[[MUL:.+]] = arith.muli %[[IDX2]], %[[C2]] : index -// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[IDX4]] : index -// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]] : index to i32 -// CHECK: %[[CMP:.+]] = arith.cmpi eq, %[[CAST]], %[[IDX]] : i32 -// CHECK: %[[SELECT:.+]] = arith.select %[[CMP]], %[[NEW]], %[[OLD]] : f32 -// CHECK: linalg.yield %[[SELECT]] - -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]] -// CHECK-NEXT{literal}: [[0], [1, 2], [3, 4], [5]] -// CHECK: return %[[COLLAPSE]] : tensor<2x8x8x1xf32> - -// ----- - -// CHECK-LABEL: set_dimension_size -// CHECK-SAME: %[[VALUE:.*]]: tensor<2x?xf32, #stablehlo.bounds -func.func @set_dimension_size( - %value: tensor<2x?xf32, #stablehlo.type_extensions>, - %dimension: tensor) - -> tensor<2x?xf32, #stablehlo.type_extensions> { - // CHECK: tensor.extract_slice %[[VALUE]][0, 0] [2, %{{.*}}] [1, 1] : tensor<2x?xf32, #stablehlo.bounds> to tensor<2x?xf32, #stablehlo.bounds> - %0 = "stablehlo.set_dimension_size"(%value, %dimension) { dimension = 1 } - : (tensor<2x?xf32, #stablehlo.type_extensions>, tensor) - -> tensor<2x?xf32, #stablehlo.type_extensions> - func.return %0 : tensor<2x?xf32, #stablehlo.type_extensions> -} - -// ----- - -func.func @torch_index_select(%arg0: tensor<5x1x5xi32>, - %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - dim = 0 : i64, - batch_dims = 0 : i64, - someattr - } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32> - func.return %0 : tensor<2x1x5xi32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) 
-> (d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @torch_index_select -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -// CHECK: %[[INIT1:.+]] = tensor.empty() : -// CHECK: %[[INIT2:.+]] = tensor.empty() : -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT1]] : -// CHECK-SAME: outs(%[[INIT2]] : -// CHECK-SAME: {someattr} -// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32, %{{.+}}: i32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[J:.+]] = linalg.index 1 -// CHECK: %[[K:.+]] = linalg.index 2 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> -// CHECK: linalg.yield %[[VAL2]] : i32 - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @torch_index_select_unsigned -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>, - %arg1: tensor<2xi32>) -> tensor<2x1x5xui32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - dim = 0 : i64, - batch_dims = 0 : i64 - } : (tensor<5x1x5xui32>, tensor<2xi32>) -> tensor<2x1x5xui32> - func.return %0 : tensor<2x1x5xui32> -} -// CHECK: %[[INPUT_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32> -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x5xi32> -// CHECK: %[[RES:.+]] = linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT]] : tensor<2xi32>, tensor<1x5xi32>) -// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32, %{{.+}}: i32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[J:.+]] = linalg.index 1 -// CHECK: %[[K:.+]] = linalg.index 2 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT_SIGNLESS]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32> -// CHECK: linalg.yield %[[VAL2]] : i32 -// CHECK: %[[RES_UNSIGNED:.+]] = builtin.unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32> -// CHECK: return %[[RES_UNSIGNED]] - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)> -// CHECK: func @torch_index_select_scalar -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_scalar(%arg0: tensor<4x8xf32>, - %arg1: tensor) -> tensor<8xf32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - batch_dims = 0 : i64, - dim = 0 : i64 - } : (tensor<4x8xf32>, tensor) -> tensor<8xf32> - func.return %0 : tensor<8xf32> -} -// CHECK: %[[T0:.+]] = tensor.empty() : tensor<8xf32> -// CHECK: %[[T1:.+]] = tensor.empty() : tensor<8xf32> -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP1]] -// CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[T0]] : tensor, tensor<8xf32>) outs(%[[T1]] : tensor<8xf32>) -// CHECK: ^{{.+}}(%[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// 
CHECK: %[[I:.+]] = linalg.index 0 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32> -// CHECK: linalg.yield %[[VAL2]] : f32 - -// ----- - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @torch_index_select_batch -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>, - %arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> { - %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { - dim = 2 : i64, - batch_dims = 1 : i64 - } : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32> - func.return %0 : tensor<4x7x1x2xf32> -} -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7x2xf32> -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT]] : -// CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32, %{{.+}}: f32): -// CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index -// CHECK: %[[I:.+]] = linalg.index 0 -// CHECK: %[[J:.+]] = linalg.index 1 -// CHECK: %[[L:.+]] = linalg.index 3 -// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32> -// CHECK: linalg.yield %[[VAL2]] : f32 - -// ----- - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @torch_index_select_dynamic -// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -func.func @torch_index_select_dynamic(%input: tensor, - %index: tensor) -> tensor{ - %0 = "stablehlo.torch_index_select"(%input, %index) { - batch_dims = 1 : i64, - dim = 2 : i64 - } : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] -// CHECK: %[[D2:.+]] = tensor.dim %[[INDEX]], %[[C1]] -// CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[D4:.+]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK: %[[D5:.+]] = tensor.dim %[[INPUT]], %[[C1]] -// CHECK: %[[D6:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[INIT0:.+]] = tensor.empty(%[[D4]], %[[D5]], %[[D6]]) : tensor -// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]], %[[INIT0]] : tensor, tensor) -// CHECK-SAME: outs(%[[INIT1]] : tensor) -// CHECK: ^{{.+}}( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32, %{{[a-zA-Z0-9_]+}}: f32) -// CHECK: %[[POS:.+]] = arith.index_cast %[[ARG0]] -// CHECK: %[[IDX0:.+]] = linalg.index 0 -// CHECK: %[[IDX1:.+]] = linalg.index 1 -// CHECK: %[[IDX3:.+]] = linalg.index 3 -// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[IDX0]], %[[IDX1]], %[[POS]], %[[IDX3]]] -// CHECK: linalg.yield %[[YIELD]] - -// 
-----
-
-// CHECK-LABEL: func @slice_whole_stride
-// CHECK: tensor.extract_slice %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32>
-func.func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> {
-  %0 = "stablehlo.slice"(%arg0) {
-    start_indices = array<i64: 1, 0>,
-    limit_indices = array<i64: 2, 4>,
-    strides = array<i64: 1, 1>
-  } : (tensor<3x4xi32>) -> tensor<1x4xi32>
-  func.return %0 : tensor<1x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_stride_part
-// CHECK: tensor.extract_slice %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32>
-func.func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> {
-  %0 = "stablehlo.slice"(%arg0) {
-    start_indices = array<i64: 1, 1>,
-    limit_indices = array<i64: 2, 3>,
-    strides = array<i64: 1, 1>
-  } : (tensor<3x4xi32>) -> tensor<1x2xi32>
-  func.return %0 : tensor<1x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_with_strides
-// CHECK: tensor.extract_slice %{{.*}}[0] [6] [2] : tensor<13xi32> to tensor<6xi32>
-func.func @slice_with_strides(%arg0: tensor<13xi32>) -> tensor<6xi32> {
-  %0 = "stablehlo.slice"(%arg0) {
-    limit_indices = array,
-    start_indices = array<i64: 0>,
-    strides = array<i64: 2>
-  } : (tensor<13xi32>) -> tensor<6xi32>
-  func.return %0 : tensor<6xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_with_strides2
-// CHECK: tensor.extract_slice %{{.*}}[0] [3] [2] : tensor<6xi32> to tensor<3xi32>
-func.func @slice_with_strides2(%arg0: tensor<6xi32>) -> tensor<3xi32> {
-  %0 = "stablehlo.slice"(%arg0) {
-    limit_indices = array,
-    start_indices = array<i64: 0>,
-    strides = array<i64: 2>
-  } : (tensor<6xi32>) -> tensor<3xi32>
-  func.return %0 : tensor<3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @slice_with_empty_result
-// CHECK: tensor.extract_slice %{{.*}}[0, 2, 0] [3, 0, 5] [1, 2, 1] : tensor<3x3x5xf64> to tensor<3x0x5xf64>
-func.func @slice_with_empty_result(%arg0: tensor<3x3x5xf64>) -> tensor<3x0x5xf64> {
-  %0 = "stablehlo.slice"(%arg0) {
-    limit_indices = array<i64: 3, 2, 5>,
-    start_indices = array<i64: 0, 2, 0>,
-    strides = array<i64: 1, 2, 1>
-  } : (tensor<3x3x5xf64>) -> tensor<3x0x5xf64>
-  func.return %0 : tensor<3x0x5xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @dynamic_slice(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
-func.func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor) -> tensor<1x4xf32> {
-  %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) {
-    slice_sizes = array<i64: 1, 4>
-  } : (tensor<3x4xf32>, tensor, tensor) -> tensor<1x4xf32>
-  func.return %0 : tensor<1x4xf32>
-}
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][] : tensor
-// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]]
-// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index
-// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C2]] : index
-// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor
-// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]]
-// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index
-// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C0]] : index
-// CHECK: tensor.extract_slice %[[ARG0]][%[[CLAMPED1]], %[[CLAMPED2]]] [1, 4] [1, 1]
-
-// -----
-
-// CHECK-LABEL: func @dynamic_slice_unsigned_index(
-func.func @dynamic_slice_unsigned_index(
-    %arg: tensor<3x4xui32>, %start1: tensor<ui32>, %start2: tensor<ui32>)
-    -> tensor<1x4xui32> {
-  %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) {
-    slice_sizes = array<i64: 1, 4>
-  } : (tensor<3x4xui32>, tensor<ui32>, tensor<ui32>) -> tensor<1x4xui32>
-
func.return %0 : tensor<1x4xui32> -} - -// CHECK: %[[EXTRACT1:.*]] = tensor.extract -// CHECK: arith.index_castui %[[EXTRACT1]] - -// ----- - -// CHECK-LABEL: func @dynamic_slice_unsigned( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xui32> { - %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = array<i64: 1, 4> - } : (tensor<3x4xui32>, tensor<i64>, tensor<i64>) -> tensor<1x4xui32> - func.return %0 : tensor<1x4xui32> -} - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[SIGNLESS_ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x4xui32> to tensor<3x4xi32> -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG1]][] : tensor<i64> -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C2]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor<i64> -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C0]] : index -// CHECK: tensor.extract_slice %[[SIGNLESS_ARG0]][%[[CLAMPED1]], %[[CLAMPED2]]] [1, 4] [1, 1] - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_update_slice(%target: tensor<3x3xi32>, %update: tensor<2x2xi32>, %c0: tensor<i32>) -> tensor<3x3xi32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<i32>, tensor<i32>) -> tensor<3x3xi32> - func.return %0 : tensor<3x3xi32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C1]] : index -// CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG1]] into %[[ARG0]] -// CHECK-SAME: [%[[CLAMPED1]], %[[CLAMPED2]]] [2, 2] [1, 1] -// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32> -// CHECK: return %[[RES]] : tensor<3x3xi32> - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice_unsigned_index( -func.func @dynamic_update_slice_unsigned_index( - %target: tensor<3x3xi32>, %update: tensor<2x2xi32>, - %idx: tensor<ui32>) -> tensor<3x3xi32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %idx, %idx) - : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<ui32>, tensor<ui32>) -> tensor<3x3xi32> - func.return %0 : tensor<3x3xi32> -} - -// CHECK: %[[EXTRACT1:.*]] = tensor.extract -// CHECK: arith.index_castui %[[EXTRACT1]] - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice_unsigned( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_update_slice_unsigned(%target: tensor<3x3xui32>, %update: tensor<2x2xui32>, %c0:
tensor<i32>) -> tensor<3x3xui32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xui32>, tensor<2x2xui32>, tensor<i32>, tensor<i32>) -> tensor<3x3xui32> - func.return %0 : tensor<3x3xui32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[SIGNLESS_UPDATE:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : tensor<2x2xui32> to tensor<2x2xi32> -// CHECK-DAG: %[[SIGNLESS_TARGET:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x3xui32> to tensor<3x3xi32> -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C1]] : index -// CHECK: %[[SIGNLESS_RES:.*]] = tensor.insert_slice %[[SIGNLESS_UPDATE]] into %[[SIGNLESS_TARGET]] -// CHECK-SAME: [%[[CLAMPED1]], %[[CLAMPED2]]] [2, 2] [1, 1] -// CHECK-SAME: : tensor<2x2xi32> into tensor<3x3xi32> -// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[SIGNLESS_RES]] : tensor<3x3xi32> to tensor<3x3xui32> -// CHECK: return %[[RES]] : tensor<3x3xui32> - -// ----- - -// CHECK-LABEL: func @dynamic_update_slice_float( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] -func.func @dynamic_update_slice_float(%target: tensor<3x3xf32>, - %update: tensor<2x2xf32>, - %c0: tensor<i32>) -> tensor<3x3xf32> { - %0 = "stablehlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xf32>, tensor<2x2xf32>, tensor<i32>, tensor<i32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTRACT1:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> -// CHECK: %[[SCALAR1:.*]] = arith.index_cast %[[EXTRACT1]] -// CHECK: %[[T1:.*]] = arith.maxsi %[[SCALAR1]], %[[C0]] : index -// CHECK: %[[CLAMPED1:.*]] = arith.minsi %[[T1]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.*]] = tensor.extract %[[ARG2]][] : tensor<i32> -// CHECK: %[[SCALAR2:.*]] = arith.index_cast %[[EXTRACT2]] -// CHECK: %[[T2:.*]] = arith.maxsi %[[SCALAR2]], %[[C0]] : index -// CHECK: %[[CLAMPED2:.*]] = arith.minsi %[[T2]], %[[C1]] : index -// CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG1]] into %[[ARG0]] -// CHECK-SAME: [%[[CLAMPED1]], %[[CLAMPED2]]] [2, 2] [1, 1] -// CHECK-SAME: : tensor<2x2xf32> into tensor<3x3xf32> -// CHECK: return %[[RES]] : tensor<3x3xf32> - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)> -// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @transpose -func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - %0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 1, 0, 3, 2>} - : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> - func.return %0 : tensor<3x2x5x9xi32> -} -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] - -// CHECK-PRIMITIVE-LABEL: func @transpose -// CHECK-PRIMITIVE: linalg.transpose - -// ----- - -// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)> -// CHECK-DAG: #[[RESULT_MAP:.*]] =
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @transpose_dynamic -func.func @transpose_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 1, 0, 3, 2>, someattr} - : (tensor) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] -// CHECK: %[[D1:.*]] = tensor.dim %arg0, %[[C1]] -// CHECK: %[[D3:.*]] = tensor.dim %arg0, %[[C3]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D0]], %[[D3]]) : tensor -// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] -// CHECK-SAME: ins(%arg0 : tensor) outs(%[[INIT]] : tensor) -// CHECK-SAME: {someattr} - -// CHECK-PRIMITIVE-LABEL: func @transpose_dynamic -// CHECK-PRIMITIVE-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-PRIMITIVE-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-PRIMITIVE: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] -// CHECK-PRIMITIVE: %[[D1:.*]] = tensor.dim %arg0, %[[C1]] -// CHECK-PRIMITIVE: %[[D3:.*]] = tensor.dim %arg0, %[[C3]] -// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D0]], %[[D3]]) : tensor -// CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE-SAME: ins(%arg0 : tensor) -// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) -// CHECK-PRIMITIVE-SAME: permutation = [1, 0, 3, 2] -// CHECK-PRIMITIVE-SAME: {someattr} - -// ----- - -func.func @transpose_unsigned(%arg0: tensor<2x2xui32>) -> tensor<2x2xui32> { - %0 = "stablehlo.transpose"(%arg0) { - permutation = array<i64: 1, 0>, - result_layout = dense<[0, 1]> : tensor<2xindex> - } : (tensor<2x2xui32>) -> tensor<2x2xui32> - return %0 : tensor<2x2xui32> -} - -// Regression test. Just check that unsigned ints lower successfully.
-// CHECK-LABEL: func @transpose_unsigned -// CHECK-PRIMITIVE-LABEL: func @transpose_unsigned diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir deleted file mode 100644 index 6a4702e83031..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_convolution.mlir +++ /dev/null @@ -1,595 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// CHECK-LABEL: @linalg.conv_0d_nc -func.func @linalg.conv_0d_nc(%arg0: tensor<3x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, f]x[i, o]->[b, f], - window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} - { - batch_group_count = 1 : i64, feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32> - func.return %0 : tensor<3x3xf32> -} -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() -// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%cst{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<3x2xf32>, tensor<2x3xf32>) outs(%[[FILL]] : tensor<3x3xf32>) - -// ----- - -// CHECK-LABEL: func @linalg.conv_1d_nwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -func.func @linalg.conv_1d_nwc(%arg0: tensor, %arg1: tensor<2x?x?xf32>) - -> tensor { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0]]> : tensor<1x2xi64>, - rhs_dilation = array, - window_strides = array, - someattr - } : (tensor, tensor<2x?x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<2x?x?xf32> -// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]], %[[DIM2]]) -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.conv_1d_nwc_wcf -// CHECK-SAME: {dilations = dense<1> : tensor<1xi64> -// CHECK-SAME: someattr -// CHECK-SAME: strides = dense<1> : tensor<1xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<2x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @conv_2d_nhwc_hwcf -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_2d_nhwc_hwcf(%arg0: tensor, %arg1: tensor<3x2x?x?xf32>) - -> tensor { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array - } : (tensor, tensor<3x2x?x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[C3]] : tensor<3x2x?x?xf32> -// CHECK: %[[INIT:.+]] = 
tensor.empty(%[[DIM0]], %[[DIM3]]) -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.conv_2d_nhwc -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<3x2x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @conv_transpose_2d -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_transpose_2d(%arg0: tensor<2x9x10x3xf32>, - %arg1: tensor<4x4x3x3xf32>) - -> tensor<2x15x25x3xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[6, 6], [6, 6]], - lhs_dilate = [1, 2], rhs_dilate = [2, 2]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, - #stablehlo] - } : (tensor<2x9x10x3xf32>, tensor<4x4x3x3xf32>) -> tensor<2x15x25x3xf32> - return %0 : tensor<2x15x25x3xf32> -} -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: %[[LHS_INIT:.+]] = tensor.empty() -// CHECK: %[[LHS_FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[LHS_INIT]] -// CHECK: %[[LHS_PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[LHS_FILL]][0, 6, 6, 0] [2, 9, 10, 3] [1, 1, 2, 1] : tensor<2x9x10x3xf32> into tensor<2x21x31x3xf32> -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: {dilations = dense<2> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[LHS_PAD]], %[[ARG1]] : tensor<2x21x31x3xf32>, tensor<4x4x3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x15x25x3xf32>) -> tensor<2x15x25x3xf32> - -// ----- - -// CHECK-LABEL: func @conv_transpose_complex_2d -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_transpose_complex_2d(%arg0: tensor<2x9x10x3xcomplex>, - %arg1: tensor<4x4x3x3xcomplex>) - -> tensor<2x15x25x3xcomplex> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[6, 6], [6, 6]], - lhs_dilate = [1, 2], rhs_dilate = [2, 2]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, - #stablehlo] - } : (tensor<2x9x10x3xcomplex>, tensor<4x4x3x3xcomplex>) -> tensor<2x15x25x3xcomplex> - return %0 : tensor<2x15x25x3xcomplex> -} -// CHECK: %[[ZERO:.+]] = complex.constant [0.000000e+00 : f32, 0.000000e+00 : f32] : complex -// CHECK: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: %[[LHS_INIT:.+]] = tensor.empty() -// CHECK: %[[LHS_FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[LHS_INIT]] -// CHECK: %[[LHS_PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[LHS_FILL]][0, 6, 6, 0] [2, 9, 10, 3] [1, 1, 2, 1] : tensor<2x9x10x3xcomplex> into tensor<2x21x31x3xcomplex> -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: {dilations = dense<2> : tensor<2xi64> -// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[LHS_PAD]], %[[ARG1]] : tensor<2x21x31x3xcomplex>, tensor<4x4x3x3xcomplex>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x15x25x3xcomplex>) -> tensor<2x15x25x3xcomplex> - -// ----- - -// Just check that this lowers successfully. 
-// CHECK-LABEL: func @conv_different_batch_dim_in_out -func.func @conv_different_batch_dim_in_out(%arg0: tensor<1x1x1xf64>, - %arg1: tensor<1x1x1xf64>) - -> tensor<1x1x1xf64> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [f, 0, b]x[i, o, 0]->[f, b, 0], - window = {stride = [1], pad = [[0, 0]], lhs_dilate = [1], - rhs_dilate = [1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x1x1xf64>, tensor<1x1x1xf64>) -> tensor<1x1x1xf64> - return %0 : tensor<1x1x1xf64> -} - -// ----- - -// Just check that this lowers successfully. -// CHECK-LABEL: func @conv_different_batch_dim_in_out_with_feature_group_count -func.func @conv_different_batch_dim_in_out_with_feature_group_count( - %arg0: tensor<4x6x7x1xf64>, %arg1: tensor<2x6x3x2xf64>) - -> tensor<1x2x1x2xf64> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f], - window = {stride = [1, 1], pad = [[0, 0], [0, -1]], - lhs_dilate = [1, 1], rhs_dilate = [1, 2], - reverse = [0, 0]} - { - batch_group_count = 1 : i64, - feature_group_count = 2 : i64, - precision_config = [#stablehlo, #stablehlo] - } : (tensor<4x6x7x1xf64>, tensor<2x6x3x2xf64>) -> tensor<1x2x1x2xf64> - return %0 : tensor<1x2x1x2xf64> -} - -// ----- - -// CHECK-LABEL: func @conv_3d_ndhwc_dhwcf -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv_3d_ndhwc_dhwcf(%arg0: tensor, %arg1: tensor<2x2x2x?x?xf32>) - -> tensor { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>, - rhs_dilation = array, - window_strides = array - } : (tensor, tensor<2x2x2x?x?xf32>) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor -// CHECK: %[[DIM4:.+]] = tensor.dim %[[ARG1]], %[[C4]] : tensor<2x2x2x?x?xf32> -// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]], %[[DIM4]]) -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]] -// CHECK: linalg.conv_3d_ndhwc_dhwcf -// CHECK-SAME: {dilations = dense<1> : tensor<3xi64> -// CHECK-SAME: strides = dense<1> : tensor<3xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor, tensor<2x2x2x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor - -// ----- - -// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>) - -> tensor<1x2x4x3xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 1 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array - } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32> - func.return %0 : tensor<1x2x4x3xf32> -} -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x4x3xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[INIT]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32> -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64> 
-// CHECK-SAME: strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32> - -// ----- - -// CHECK-LABEL: func @linalg.conv_2D_padding_test1 -// CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>) -func.func @linalg.conv_2D_padding_test1(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>) - -> tensor<400x1024x1024x1xf16> { - %0 = stablehlo.convolution(%arg1, %arg0) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = { stride = [1, 1], pad = [[0, 0], [16, 16]], rhs_dilate = [1, 1] } - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64 - } : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1024x1024x1xf16>) - func.return %0 : tensor<400x1024x1024x1xf16> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16 -// CHECK-NEXT: %[[INIT:.*]] = tensor.empty() : tensor<400x1024x1024x1xf16> -// CHECK-NEXT: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16> -// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 0, 16, 0] high[0, 0, 16, 0] { -// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK-NEXT: tensor.yield %[[ZERO]] : f16 -// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1024x1056x1xf16> -// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %[[FILTER]] : tensor<400x1024x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%[[FILL]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16> -// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16> - -// ----- - -// CHECK-LABEL: func @linalg.conv_2D_padding_test2 -// CHECK-SAME: (%[[FILTER:.*]]: tensor<1x33x1x1xf16>, %[[INPUT:.*]]: tensor<400x1024x1024x1xf16>) -func.func @linalg.conv_2D_padding_test2(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<400x1024x1024x1xf16>) - -> tensor<400x1040x1024x1xf16> { - %0 = stablehlo.convolution(%arg1, %arg0) - dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[8, 8], [16, 16]], rhs_dilate = [1, 1]} - { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64 - } : (tensor<400x1024x1024x1xf16>, tensor<1x33x1x1xf16>) -> (tensor<400x1040x1024x1xf16>) - return %0 : tensor<400x1040x1024x1xf16> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16 -// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<400x1040x1024x1xf16> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]] : f16) outs(%[[INIT]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16> -// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0] { -// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK-NEXT: tensor.yield %[[ZERO]] : f16 -// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1040x1056x1xf16> -// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%[[FILL]] : tensor<400x1040x1024x1xf16>) -> tensor<400x1040x1024x1xf16> -// CHECK-NEXT: return %[[RESULT]] : tensor<400x1040x1024x1xf16> - -// ----- - -// CHECK-LABEL: func @depthwise_conv -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: 
%[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv(%arg0: tensor<2x4x5x2xf32>, - %arg1: tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 2 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array, - someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> tensor<2x3x4x6xf32> - func.return %0 : tensor<2x3x4x6xf32> -} -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2, 3]] : tensor<2x2x1x6xf32> into tensor<24xf32> -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1, 2, 3]] output_shape [2, 2, 2, 3] : tensor<24xf32> into tensor<2x2x2x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x4x2x3xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32> -// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[IN]], %[[EXPAND]] : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32> -// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]] -// CHECK-SAME: [0], [1], [2], [3, 4] -// CHECK-SAME: : tensor<2x3x4x2x3xf32> into tensor<2x3x4x6xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv_with_padding -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv_with_padding( - %arg0: tensor<2x4x5x2xf32>, - %arg1: tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 2 : i64, - padding = dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array, - someattr} : (tensor<2x4x5x2xf32>, tensor<2x2x1x4xf32>) -> tensor<2x3x6x4xf32> - func.return %0 : tensor<2x3x6x4xf32> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 0, 1, 0] high[0, 0, 1, 0] { -// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK: tensor.yield %[[ZERO]] : f32 -// CHECK } : tensor<2x4x5x2xf32> to tensor<2x4x7x2xf32> -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0, 1, 2, 3] -// CHECK-SAME: : tensor<2x2x1x4xf32> into tensor<16xf32> -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] -// CHECK-SAME: [0, 1, 2, 3] -// CHECK-SAME: tensor<16xf32> into tensor<2x2x2x2xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x6x2x2xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[INIT]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32> -// CHECK: %[[OUT:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, someattr, strides = dense<1> : tensor<2xi64>} -// CHECK-SAME: ins(%[[PAD]], %[[EXPAND]] : tensor<2x4x7x2xf32>, tensor<2x2x2x2xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x3x6x2x2xf32>) -> tensor<2x3x6x2x2xf32> -// CHECK: %{{.+}} = tensor.collapse_shape %[[OUT]] -// CHECK-SAME: [0], [1], [2], [3, 4] -// CHECK-SAME: : tensor<2x3x6x2x2xf32> into tensor<2x3x6x4xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv_multiplier_1 -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// 
CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv_multiplier_1(%arg0: tensor<1x113x113x96xf32>, - %arg1: tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 96 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32> - func.return %0 : tensor<1x56x56x96xf32> -} -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x56x56x96xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0], [1], [2, 3] -// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32> -// CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} -// CHECK-SAME: ins(%[[IN]], %[[RESHAPED_FILTER]] : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv_multiplier_1_with_padding -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv_multiplier_1_with_padding( - %arg0: tensor<1x113x113x96xf32>, - %arg1: tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> { - %0 = "stablehlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #stablehlo.conv, - feature_group_count = 96 : i64, - padding = dense<[[1, 1], [2, 2]]> : tensor<2x2xi64>, - rhs_dilation = array, - window_strides = array} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x57x58x96xf32> - func.return %0 : tensor<1x57x58x96xf32> -} -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[PAD:.*]] = tensor.pad %[[IN]] low[0, 1, 2, 0] high[0, 1, 2, 0] { -// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): -// CHECK: tensor.yield %[[ZERO]] : f32 -// CHECK } : tensor<1x113x113x96xf32> to tensor<1x115x117x96xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x58x96xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[INIT]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] -// CHECK-SAME: [0], [1], [2, 3] -// CHECK-SAME: : tensor<3x3x1x96xf32> into tensor<3x3x96xf32> -// CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc -// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} -// CHECK-SAME: ins(%[[PAD]], %[[RESHAPED_FILTER]] : tensor<1x115x117x96xf32>, tensor<3x3x96xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x57x58x96xf32>) -> tensor<1x57x58x96xf32> - -// ----- - -// CHECK-LABEL: func @depthwise_conv1d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv1d(%arg0: tensor<1x10x8xf32>, - %arg1: tensor<3x1x16xf32>) -> tensor<1x10x16xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], - window = { - stride = [1], - pad = [[1, 1]], - lhs_dilate = [1], - rhs_dilate = [1], - reverse = [0]} { - batch_group_count = 1 : i64, - feature_group_count = 8 : i64, - someattr} : (tensor<1x10x8xf32>, tensor<3x1x16xf32>) -> 
tensor<1x10x16xf32> - func.return %0 : tensor<1x10x16xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_1d_nwc_wcm -// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[CONV]] -// CHECK: return %[[OUT]] - -// ----- - -// CHECK-LABEL: func @depthwise_conv1d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv1d_m1(%arg0: tensor<1x10x8xf32>, - %arg1: tensor<3x1x8xf32>) -> tensor<1x10x8xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], - window = { - stride = [1], - pad = [[1, 1]], - lhs_dilate = [1], - rhs_dilate = [1], - reverse = [0]} { - batch_group_count = 1 : i64, - feature_group_count = 8 : i64, - someattr} : (tensor<1x10x8xf32>, tensor<3x1x8xf32>) -> tensor<1x10x8xf32> - func.return %0 : tensor<1x10x8xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_1d_nwc_wc -// CHECK: return %[[CONV]] - -// ----- - -// CHECK-LABEL: func @depthwise_conv3d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv3d(%arg0: tensor<2x3x5x4x6xf32>, - %arg1: tensor<2x1x3x1x36xf32>) - -> tensor<2x3x13x4x36xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], - window = { - stride = [2, 1, 3], - pad = [[1, 2], [5, 3], [3, 5]], - lhs_dilate = [1, 1, 1], - rhs_dilate = [1, 1, 1], - reverse = [0, 0, 0]} { - batch_group_count = 1 : i64, - feature_group_count = 6 : i64, - someattr} : (tensor<2x3x5x4x6xf32>, tensor<2x1x3x1x36xf32>) - -> tensor<2x3x13x4x36xf32> - func.return %0 : tensor<2x3x13x4x36xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_3d_ndhwc_dhwcm -// CHECK: %[[OUT:.+]] = tensor.collapse_shape %[[CONV]] -// CHECK: return %[[OUT]] - -// ----- - -// CHECK-LABEL: func @depthwise_conv3d -// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[FILTER:[a-zA-Z0-9_]*]] -func.func @depthwise_conv3d_m1(%arg0: tensor<2x3x5x4x6xf32>, - %arg1: tensor<2x1x3x1x6xf32>) - -> tensor<2x3x13x4x6xf32> { - %0 = stablehlo.convolution(%arg0, %arg1) - dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], - window = { - stride = [2, 1, 3], - pad = [[1, 2], [5, 3], [3, 5]], - lhs_dilate = [1, 1, 1], - rhs_dilate = [1, 1, 1], - reverse = [0, 0, 0]} { - batch_group_count = 1 : i64, - feature_group_count = 6 : i64, - someattr} : (tensor<2x3x5x4x6xf32>, tensor<2x1x3x1x6xf32>) - -> tensor<2x3x13x4x6xf32> - func.return %0 : tensor<2x3x13x4x6xf32> -} -// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_3d_ndhwc_dhwc -// CHECK: return %[[CONV]] diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir deleted file mode 100644 index 1c4f3e2bfcbb..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_dot_prod.mlir +++ /dev/null @@ -1,276 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// Note: We need the canonicalization pass to deduplicate constants. This test -// does not rely on it to simplify arithmetic, etc. 
- -func.func @dot_general(%arg0: tensor<?x?x?xf32>, - %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [2], - rhs_contracting_dimensions = [1] - >, - precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>], - someattr - } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> - func.return %0 : tensor<?x?x?xf32> -} -// The iterations are (Batch Dim, LHS Other Dim, RHS Other dim, Contracting Dim) -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)> -// Output is the iterators excluding contracting -// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK: func @dot_general( -// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// Only contracting dims are reductions -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -// CHECK-SAME: {someattr} -// CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): -// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG2]], %[[ARG3]] : f32 -// CHECK: %[[SUM:.*]] = arith.addf %[[ARG4]], %[[MUL]] : f32 -// CHECK: linalg.yield %[[SUM]] : f32 -// CHECK: } -> tensor<?x?x?xf32> - -// ----- - -func.func @dot_general_unsigned(%arg0: tensor<?x?x?xui32>, - %arg1: tensor<?x?x?xui32>) -> tensor<?x?x?xui32> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [2], - rhs_contracting_dimensions = [1] - >, - precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>], - someattr - } : (tensor<?x?x?xui32>, tensor<?x?x?xui32>) -> tensor<?x?x?xui32> - func.return %0 : tensor<?x?x?xui32> -} - -// CHECK-LABEL: func @dot_general_unsigned( -// CHECK: linalg.generic -// CHECK-SAME: ins({{.*}} : tensor<?x?x?xi32>, tensor<?x?x?xi32>) -// CHECK-SAME: outs({{.*}} : tensor<?x?x?xi32>) - -// ----- - -func.func @dot_general_complex(%arg0: tensor<?x?x?xcomplex<f32>>, - %arg1: tensor<?x?x?xcomplex<f32>>) -> tensor<?x?x?xcomplex<f32>> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [1], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [2], - rhs_contracting_dimensions = [1] - >, - precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>], - someattr - } : (tensor<?x?x?xcomplex<f32>>, tensor<?x?x?xcomplex<f32>>) -> tensor<?x?x?xcomplex<f32>> - func.return %0 : tensor<?x?x?xcomplex<f32>> -} - -// CHECK-LABEL: func @dot_general_complex( -// CHECK: linalg.generic -// CHECK: complex.mul -// CHECK: complex.add - -// ----- - -func.func @dot_general_multiple_batch_dimensions(%arg0: tensor<3x4x2x4xi32>, - %arg1: tensor<3x4x3x2xi32>) -> tensor<3x4x4x3xi32> { - %0 = "stablehlo.dot_general"(%arg0, %arg1) { - dot_dimension_numbers = #stablehlo.dot< - lhs_batching_dimensions = [0, 1], - rhs_batching_dimensions = [0, 1], - lhs_contracting_dimensions = [2], - rhs_contracting_dimensions = [3]>, - precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>], - someattr - } : (tensor<3x4x2x4xi32>, tensor<3x4x3x2xi32>) -> tensor<3x4x4x3xi32> - return %0 :
tensor<3x4x4x3xi32> -} - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d2)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)> -// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> -// CHECK: func @dot_general_multiple_batch_dimensions -// CHECK-SAME: (%[[ARG0:.+]]: tensor<3x4x2x4xi32>, %[[ARG1:.+]]: tensor<3x4x3x2xi32>) -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3x4x2x4xi32>, tensor<3x4x3x2xi32>) -// CHECK-SAME: outs({{.*}} : tensor<3x4x4x3xi32>) -// CHECK-SAME: {someattr} - -// ----- - -func.func @dot_matmul(%arg0: tensor<2x3xf32>, - %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> { - %0 = "stablehlo.dot"(%arg0, %arg1) {someattr} - : (tensor<2x3xf32>, tensor<3x?xf32>) -> tensor<2x?xf32> - func.return %0 : tensor<2x?xf32> -} -// CHECK-LABEL: func @dot_matmul -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: {someattr} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xf32>, tensor<3x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xf32>) - -// ----- - -func.func @dot_matmul_complex(%arg0: tensor<2x3xcomplex<f32>>, - %arg1: tensor<3x?xcomplex<f32>>) -> tensor<2x?xcomplex<f32>> { - %0 = "stablehlo.dot"(%arg0, %arg1) {someattr} - : (tensor<2x3xcomplex<f32>>, tensor<3x?xcomplex<f32>>) -> tensor<2x?xcomplex<f32>> - func.return %0 : tensor<2x?xcomplex<f32>> -} -// CHECK-LABEL: func @dot_matmul_complex( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xcomplex<f32>>, %[[ARG1:.*]]: tensor<3x?xcomplex<f32>>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: {someattr} -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xcomplex<f32>>, tensor<3x?xcomplex<f32>>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xcomplex<f32>>) - -// ----- - -func.func @dot_matmul_i8_i8_i32(%arg0: tensor<2x3xi8>, - %arg1: tensor<3x?xi8>) -> tensor<2x?xi32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>, - tensor<3x?xi8>) -> tensor<2x?xi32> - func.return %0 : tensor<2x?xi32> -} -// CHECK-LABEL: func @dot_matmul_i8_i8_i32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi8>, %[[ARG1:.*]]: tensor<3x?xi8>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi8>, tensor<3x?xi8>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) - -// ----- - -func.func @dot_matmul_i16_i16_i32(%arg0: tensor<2x3xi16>, - %arg1: tensor<3x?xi16>) -> tensor<2x?xi32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xi16>, - tensor<3x?xi16>) -> tensor<2x?xi32> - func.return %0 : tensor<2x?xi32> -} -// CHECK-LABEL: func @dot_matmul_i16_i16_i32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi16>, %[[ARG1:.*]]: tensor<3x?xi16>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK:
%[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi16>, tensor<3x?xi16>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) - -// ----- - -func.func @dot_matmul_i32_i32_i32(%arg0: tensor<2x3xi32>, - %arg1: tensor<3x?xi32>) -> tensor<2x?xi32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<2x3xi32>, - tensor<3x?xi32>) -> tensor<2x?xi32> - func.return %0 : tensor<2x?xi32> -} -// CHECK-LABEL: func @dot_matmul_i32_i32_i32( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xi32>, %[[ARG1:.*]]: tensor<3x?xi32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) : tensor<2x?x -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3xi32>, tensor<3x?xi32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<2x?xi32>) - -// ----- - -func.func @dot_matvec(%arg0: tensor<?x3xf32>, - %arg1: tensor<3xf32>) -> tensor<?xf32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<?x3xf32>, - tensor<3xf32>) -> tensor<?xf32> - func.return %0 : tensor<?xf32> -} -// CHECK-LABEL: func @dot_matvec( -// CHECK-SAME: %[[ARG0:.*]]: tensor<?x3xf32>, %[[ARG1:.*]]: tensor<3xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D0]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.matvec -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x3xf32>, tensor<3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>) - -// ----- - -func.func @dot_vecmat(%arg0: tensor<3xf32>, - %arg1: tensor<3x?xf32>) -> tensor<?xf32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<3xf32>, - tensor<3x?xf32>) -> tensor<?xf32> - func.return %0 : tensor<?xf32> -} -// CHECK-LABEL: func @dot_vecmat( -// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>, %[[ARG1:.*]]: tensor<3x?xf32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[D1]]) -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.vecmat -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<3xf32>, tensor<3x?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>) - -// ----- - -func.func @dot_dot(%arg0: tensor<?xf32>, - %arg1: tensor<?xf32>) -> tensor<f32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<f32> - func.return %0 : tensor<f32> -} -// CHECK-LABEL: func @dot_dot( -// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>) -// CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT]] -// CHECK: linalg.dot -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?xf32>, tensor<?xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<f32>) - -// ----- - -func.func @dot_dot_unsigned(%arg0: tensor<?xui32>, - %arg1: tensor<?xui32>) -> tensor<ui32> { - %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<?xui32>, tensor<?xui32>) -> tensor<ui32> - func.return %0 : tensor<ui32> -} -// CHECK-LABEL: func @dot_dot_unsigned( -// CHECK-SAME: %[[ARG0:.*]]: tensor<?xui32>, %[[ARG1:.*]]: tensor<?xui32>) -// CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}}outs(%[[INIT]] -// CHECK: linalg.dot -// CHECK-SAME: ins(%{{.*}} : tensor<?xi32>, tensor<?xi32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<i32>) diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir
b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir deleted file mode 100644 index e585d5963a33..000000000000 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_gather.mlir +++ /dev/null @@ -1,325 +0,0 @@ -// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \ -// RUN: --canonicalize | FileCheck %s - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func.func @gather( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather(%operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xi32>) -> tensor<1x8x8xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 1, 1, 8>, - someattr - } : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32> - func.return %res : tensor<1x8x8xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: outs(%[[INIT]] : tensor<1x8x8xi32>) -// CHECK-SAME: {someattr} -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[IDX1]], %[[C0]]] : tensor<1x8x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[IDX1]], %[[C1]]] : tensor<1x8x2xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.minsi %[[CLAMP0]], %[[C0]] -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[IN1:.+]] = arith.minsi %[[CLAMP1]], %[[C3]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]], %[[IDX2]]] : tensor<1x4x8xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK-DAG: return %[[RES]] - -// ----- - -// CHECK-LABEL: func.func @gather_unsigned_index( -func.func @gather_unsigned_index( - %operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xui32>) - -> tensor<1x8x8xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 1, 1, 8>, - someattr - } : (tensor<1x4x8xi32>, tensor<1x8x2xui32>) -> tensor<1x8x8xi32> - func.return %res : tensor<1x8x8xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK: %[[S0_INT:.+]] = tensor.extract {{.*}}[{{.*}}, %[[C0]]] -// CHECK: arith.index_castui %[[S0_INT]] : i32 to index -// CHECK: %[[S1_INT:.+]] = tensor.extract {{.*}}[{{.*}}, %[[C1]]] -// CHECK: arith.index_castui %[[S1_INT]] : i32 to index - -// ----- - -// CHECK-LABEL: func @gather_unsigned( -func.func @gather_unsigned(%operand : tensor<1x4x8xui32>, %start_indices : tensor<1x8x2xi32>) -> tensor<1x8x8xui32>
{ - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 1, 1, 8> - } : (tensor<1x4x8xui32>, tensor<1x8x2xi32>) -> tensor<1x8x8xui32> - func.return %res : tensor<1x8x8xui32> -} - -// CHECK: linalg.generic -// CHECK-SAME: outs(%{{.*}} : tensor<1x8x8xi32>) - -// ----- - -// CHECK-LABEL: func.func @gather_no_collapse( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather_no_collapse(%operand : tensor<6x3xi32>, %start_indices : tensor<5x2xi32>) -> tensor<5x4x2xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 4, 2> - } : (tensor<6x3xi32>, tensor<5x2xi32>) -> tensor<5x4x2xi32> - func.return %res : tensor<5x4x2xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x4x2xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<5x4x2xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C0]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C1]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[C2]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX1]] : index -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP1_1:.+]] = arith.minsi %[[CLAMP1]], %[[C1]] -// CHECK-DAG: %[[IN1:.+]] = arith.addi %[[CLAMP1_1]], %[[IDX2]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]]] : tensor<6x3xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - - -// ----- - -func.func @gather_max_offset(%operand : tensor<?x?x?xi32>, %start_indices : tensor<5x2xi32>) -> tensor<2x3x4x5xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [0, 1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 2, 3, 4> - } : (tensor<?x?x?xi32>, tensor<5x2xi32>) -> tensor<2x3x4x5xi32> - func.return %res : tensor<2x3x4x5xi32> -} - -// CHECK-LABEL: func @gather_max_offset( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2x3x4x5xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<2x3x4x5xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] =
linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX3]], %[[C0]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX3]], %[[C1]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C2]] -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] -// CHECK-DAG: %[[L1:.+]] = arith.subi %[[D1]], %[[C3]] -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP1_1:.+]] = arith.minsi %[[CLAMP1]], %[[L1]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index -// CHECK-DAG: %[[IN1:.+]] = arith.addi %[[CLAMP1_1]], %[[IDX1]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]], %[[IDX2]]] : tensor<?x?x?xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - -// ----- - -func.func @gather_reorder_start_index(%operand : tensor<6x3x2x7xi32>, %start_indices : tensor<5x4xi32>) -> tensor<5x2x4xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [0, 2], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [3, 1, 2, 0] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 1, 2, 1, 4> - } : (tensor<6x3x2x7xi32>, tensor<5x4xi32>) -> tensor<5x2x4xi32> - func.return %res : tensor<5x2x4xi32> -} - -// CHECK-LABEL: func @gather_reorder_start_index( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[C5:.+]] = arith.constant 5 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x2x4xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<5x2x4xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C0]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[S1_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C1]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S1:.+]] = arith.index_cast %[[S1_INT]] : i32 to index -// CHECK-DAG: %[[S2_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C2]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S2:.+]] = arith.index_cast %[[S2_INT]] : i32 to index -// CHECK-DAG: %[[S3_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX0]], %[[C3]]] : tensor<5x4xi32> -// CHECK-DAG: %[[S3:.+]] = arith.index_cast %[[S3_INT]] : i32 to index -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[C3]] -// CHECK-DAG: %[[CLAMP1:.+]] = arith.maxsi %[[S1]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP1_1:.+]] = arith.minsi %[[CLAMP1]], %[[C1]] -// CHECK-DAG: %[[CLAMP2:.+]] = arith.maxsi %[[S2]], %[[C0]] : index -//
CHECK-DAG: %[[IN2:.+]] = arith.minsi %[[CLAMP2]], %[[C1]] -// CHECK-DAG: %[[CLAMP3:.+]] = arith.maxsi %[[S3]], %[[C0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.minsi %[[CLAMP3]], %[[C5]] -// CHECK-DAG: %[[IN1:.+]] = arith.addi %[[CLAMP1_1]], %[[IDX1]] : index -// CHECK-DAG: %[[IN3:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX2]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IN1]], %[[IN2]], %[[IN3]]] : tensor<6x3x2x7xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - -// ----- - -// CHECK-LABEL: func.func @gather_implicit_trailing_dim( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather_implicit_trailing_dim(%operand : tensor<?x?xi32>, %start_indices : tensor<5x2xi32>) -> tensor<3x4x5x2xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 2, - offset_dims = [0, 1], - start_index_map = [0] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 3, 4> - } : (tensor<?x?xi32>, tensor<5x2xi32>) -> tensor<3x4x5x2xi32> - func.return %res : tensor<3x4x5x2xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<3x4x5x2xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<3x4x5x2xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX2]], %[[IDX3]]] : tensor<5x2xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] -// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C3]] -// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index -// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index -// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index -// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IDX1]]] : tensor<?x?xi32> -// CHECK: linalg.yield %[[Y]] : i32 -// CHECK: return %[[RES]] - -// ----- - -// CHECK-LABEL: func.func @gather_non_static( -// CHECK-SAME: %[[OPERAND:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[START_INDICES:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @gather_non_static(%operand : tensor<?x?xi32>, %start_indices : tensor<?x1xi32>) -> tensor<3x4x?xi32> { - %res = "stablehlo.gather"(%operand, %start_indices) { - dimension_numbers = #stablehlo.gather< - collapsed_slice_dims = [], - index_vector_dim = 1, - offset_dims = [0, 1], - start_index_map = [0] - >, - indices_are_sorted = false, - slice_sizes = array<i64: 3, 4> - } : (tensor<?x?xi32>, tensor<?x1xi32>) -> tensor<3x4x?xi32> - func.return %res : tensor<3x4x?xi32> -} - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 -// CHECK-DAG: %[[DYN_DIM:.+]] = tensor.dim %[[START_INDICES]], %[[C0]] -// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[DYN_DIM]]) : tensor<3x4x?xi32> -// CHECK: %[[RES:.+]] = linalg.generic -// CHECK-SAME: outs(%[[INIT]] : tensor<3x4x?xi32>) { -// CHECK: ^bb0 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 -// CHECK-DAG: %[[S0_INT:.+]] = tensor.extract %[[START_INDICES]][%[[IDX2]], %[[C0]]] : tensor<?x1xi32> -// CHECK-DAG: %[[S0:.+]] = arith.index_cast %[[S0_INT]] : i32 to index -// CHECK-DAG:
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
-// CHECK-DAG: %[[L0:.+]] = arith.subi %[[D0]], %[[C3]]
-// CHECK-DAG: %[[CLAMP0:.+]] = arith.maxsi %[[S0]], %[[C0]] : index
-// CHECK-DAG: %[[CLAMP0_1:.+]] = arith.minsi %[[CLAMP0]], %[[L0]] : index
-// CHECK-DAG: %[[IN0:.+]] = arith.addi %[[CLAMP0_1]], %[[IDX0]] : index
-// CHECK: %[[Y:.+]] = tensor.extract %[[OPERAND]][%[[IN0]], %[[IDX1]]] : tensor<?x?xi32>
-// CHECK: linalg.yield %[[Y]] : i32
-// CHECK: return %[[RES]]
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir
deleted file mode 100644
index e1421652e319..000000000000
--- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_pointwise.mlir
+++ /dev/null
@@ -1,1489 +0,0 @@
-// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \
-// RUN: --canonicalize | FileCheck %s
-
-// RUN: iree-opt %s --iree-stablehlo-to-linalg="enable-primitive-ops=true" \
-// RUN: --split-input-file --canonicalize | \
-// RUN: FileCheck %s --check-prefix=CHECK-PRIMITIVE
-
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @float_add
-// CHECK-PRIMITIVE-LABEL: func @float_add
-func.func @float_add(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK-SAME: {someattr}
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
- // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
- // CHECK: linalg.yield %[[RESULT]]
-
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: arith.addf
- %0 = "stablehlo.add"(%lhs, %rhs) {someattr}
- : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_add_dynamic_encoding
-// CHECK-PRIMITIVE-LABEL: func @float_add_dynamic_encoding
-func.func @float_add_dynamic_encoding(
- %lhs: tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>>,
- %rhs: tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>>)
- -> tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>> {
- // CHECK: linalg.generic
- // CHECK: arith.addf
- // CHECK: linalg.yield
-
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: arith.addf
- %0 = "stablehlo.add"(%lhs, %rhs)
- : (tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>>,
- tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>>)
- -> tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>>
- func.return %0 : tensor<2x?xf32, #stablehlo.type_extensions<bounds = [?, 2]>>
-}
-
-// -----
-
-// CHECK-LABEL: integer_add
-// CHECK-PRIMITIVE-LABEL: integer_add
-func.func @integer_add(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: addi
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: addi
- %0 = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: complex_add
-// CHECK-PRIMITIVE-LABEL: complex_add
-func.func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
- %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.add
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.add
- %0 = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
- tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_atan2
-// CHECK-PRIMITIVE-LABEL: func @complex_atan2
-func.func @complex_atan2(%lhs: tensor<2x2xcomplex<f32>>,
- %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- %tensor_result = "stablehlo.atan2"(%lhs, %rhs)
- : (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- // CHECK: linalg.generic
- // CHECK: complex.atan2
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.atan2
- func.return %tensor_result : tensor<2x2xcomplex<f32>>
-}
-
-
-// -----
-
-// CHECK-LABEL: func @float_mul
-// CHECK-PRIMITIVE-LABEL: func @float_mul
-func.func @float_mul(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: mulf
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: mulf
- %0 = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
- tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_mul
-// CHECK-PRIMITIVE-LABEL: func @integer_mul
-func.func @integer_mul(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: muli
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: muli
- %0 = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_mul
-// CHECK-PRIMITIVE-LABEL: func @complex_mul
-func.func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
- %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.mul
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.mul
- %0 = "stablehlo.multiply"(%lhs, %rhs)
- : (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_remainder
-// CHECK-PRIMITIVE-LABEL: func @float_remainder
-func.func @float_remainder(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: remf
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: remf
- %0 = "stablehlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
- tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_remainder
-// CHECK-PRIMITIVE-LABEL: func @integer_remainder
-func.func @integer_remainder(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: arith.remsi
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: arith.remsi
- %0 = "stablehlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @population_count_integer
-// CHECK-PRIMITIVE-LABEL: func @population_count_integer
-func.func @population_count_integer(%lhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: math.ctpop
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.ctpop
- %0 = "stablehlo.popcnt"(%lhs) : (tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_sqrt
-// CHECK-PRIMITIVE-LABEL: func @complex_sqrt
-func.func @complex_sqrt(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- %tensor_result = "stablehlo.sqrt"(%operand)
- : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- // CHECK: linalg.generic
- // CHECK: complex.sqrt
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.sqrt
- func.return %tensor_result : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_rsqrt
-// CHECK-PRIMITIVE-LABEL: func @float_rsqrt
-func.func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
- %tensor_result = "stablehlo.rsqrt"(%operand)
- : (tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK: linalg.generic
- // CHECK: rsqrt
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: rsqrt
- func.return %tensor_result : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_rsqrt
-// CHECK-PRIMITIVE-LABEL: func @complex_rsqrt
-func.func @complex_rsqrt(%operand: tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>> {
- %tensor_result = "stablehlo.rsqrt"(%operand)
- : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- // CHECK: linalg.generic
- // CHECK: complex.rsqrt
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.rsqrt
- func.return %tensor_result : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_cbrt
-// CHECK-PRIMITIVE-LABEL: func @float_cbrt
-func.func @float_cbrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
- %tensor_result = "stablehlo.cbrt"(%operand)
- : (tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
- // CHECK: %[[RESULT:.+]] = math.cbrt %[[IN]]
- // CHECK: linalg.yield %[[RESULT]]
-
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.cbrt
- func.return %tensor_result : tensor<2x2xf32>
-}
-
-// -----
-
-
-// CHECK-LABEL: func @float_sub
-// CHECK-PRIMITIVE-LABEL: func @float_sub
-func.func @float_sub(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: subf
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: subf
- %0 = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
- tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_sub
-// CHECK-PRIMITIVE-LABEL: func @integer_sub
-func.func @integer_sub(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: subi
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: subi
- %0 = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: complex_sub
-// CHECK-PRIMITIVE-LABEL: complex_sub
-func.func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
- %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.sub
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.sub
- %0 = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
- tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_abs
-// CHECK-PRIMITIVE-LABEL: func @float_abs
-func.func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK-SAME: {someattr}
- // CHECK: math.absf
- // CHECK-PRIMITIVE: linalg.map { math.absf }
- // CHECK-PRIMITIVE-SAME: ins(
- // CHECK-PRIMITIVE-SAME: outs(
- // CHECK-PRIMITIVE-SAME: {someattr}
- %0 = "stablehlo.abs"(%arg0) {someattr} : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_exp
-// CHECK-PRIMITIVE-LABEL: func @float_exp
-func.func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: exp
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: exp
- %0 = "stablehlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_exp
-func.func @complex_exp(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.exp
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.exp
- %0 = "stablehlo.exponential"(%arg0) : (tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_expm1
-func.func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: expm1
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: expm1
- %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_expm1
-func.func @complex_expm1(%arg0: tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.expm1
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.expm1
- %0 = "stablehlo.exponential_minus_one"(%arg0)
- : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_log
-func.func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: math.log
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.log
- %0 = "stablehlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_log
-func.func @complex_log(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.log
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.log
- %0 = "stablehlo.log"(%arg0) : (tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_log1p
-// CHECK-PRIMITIVE-LABEL: func @float_log1p
-func.func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: math.log1p
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.log1p
- %0 = "stablehlo.log_plus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_log1p
-// CHECK-PRIMITIVE-LABEL: func @complex_log1p
-func.func @complex_log1p(%arg0: tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.log1p
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.log1p
- %0 = "stablehlo.log_plus_one"(%arg0) : (tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_logistic
-// CHECK-PRIMITIVE-LABEL: func @float_logistic
-func.func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: %[[C1:.*]] = arith.constant 1.{{.*}}e+00
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32):
- // CHECK: %[[NEG_ARG:.*]] = arith.negf %[[ARG]]
- // CHECK: %[[EXP_NEG_ARG:.*]] = math.exp %[[NEG_ARG]]
- // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = arith.addf %[[EXP_NEG_ARG]], %[[C1]]
- // CHECK: %[[RESULT:.*]] = arith.divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]]
- // CHECK: linalg.yield %[[RESULT]]
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: arith.negf
- // CHECK-PRIMITIVE: math.exp
- // CHECK-PRIMITIVE: arith.addf
- // CHECK-PRIMITIVE: arith.divf
- %0 = "stablehlo.logistic"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_logistic
-func.func @complex_logistic(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG:.*]]: complex<f32>, %{{.*}}: complex<f32>):
- // CHECK: %[[NEG_ARG:.*]] = complex.neg %[[ARG]]
- // CHECK: %[[EXP_NEG_ARG:.*]] = complex.exp %[[NEG_ARG]]
- // CHECK: %[[CC1:.*]] = complex.create %[[C1]], %[[C0]] : complex<f32>
- // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = complex.add %[[EXP_NEG_ARG]], %[[CC1]]
- // CHECK: %[[RESULT:.*]] = complex.div %[[CC1]], %[[ONE_ADD_EXP_NEG_ARG]]
- // CHECK: linalg.yield %[[RESULT]]
- %0 = "stablehlo.logistic"(%arg0) : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_ceil
-// CHECK-PRIMITIVE-LABEL: func @float_ceil
-func.func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: math.ceil
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.ceil
- %0 = "stablehlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @floor
-// CHECK-PRIMITIVE-LABEL: func @floor
-func.func @floor(%input: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: math.floor
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.floor
- %0 = "stablehlo.floor"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_neg
-// CHECK-PRIMITIVE-LABEL: func @float_neg
-func.func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: negf
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: negf
- %0 = "stablehlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_neg
-// CHECK-PRIMITIVE-LABEL: func @complex_neg
-func.func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.neg
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.neg
- %0 = "stablehlo.negate"(%arg0) : (tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_sign
-// CHECK-PRIMITIVE-LABEL: func @complex_sign
-func.func @complex_sign(
- %arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.sign
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.sign
- %0 = "stablehlo.sign"(%arg0) : (tensor<2x2xcomplex<f32>>)
- -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_tanh
-// CHECK-PRIMITIVE-LABEL: func @float_tanh
-func.func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: tanh
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: tanh
- %0 = "stablehlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_tanh
-// CHECK-PRIMITIVE-LABEL: func @complex_tanh
-func.func @complex_tanh(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- %tensor_result = "stablehlo.tanh"(%operand)
- : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- // CHECK: linalg.generic
- // CHECK: complex.tanh
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.tanh
- func.return %tensor_result : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_and
-// CHECK-PRIMITIVE-LABEL: func @integer_and
-func.func @integer_and(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: and
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: and
- %0 = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_or
-// CHECK-PRIMITIVE-LABEL: func @integer_or
-func.func @integer_or(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: or
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: or
- %0 = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_xor
-// CHECK-PRIMITIVE-LABEL: func @integer_xor
-func.func @integer_xor(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: xor
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: xor
- %0 = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @count_leading_zeros
-// CHECK-PRIMITIVE-LABEL: func @count_leading_zeros
-func.func @count_leading_zeros(%lhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: math.ctlz
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.ctlz
- %0 = "stablehlo.count_leading_zeros"(%lhs) : (tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: unsigned_convert
-func.func @unsigned_convert(%in: tensor<2x2xui32>) -> tensor<2x2xui64> {
- // CHECK: linalg.generic
- // CHECK: arith.extui
- %0 = "stablehlo.convert"(%in) : (tensor<2x2xui32>) -> tensor<2x2xui64>
- func.return %0 : tensor<2x2xui64>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_cmp
-// CHECK-PRIMITIVE-LABEL: func @float_cmp
-func.func @float_cmp(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
- %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction EQ>}
- : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
- func.return %0 : tensor<2x2xi1>
-}
-// CHECK: tensor.empty() : tensor<2x2xi1>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpf oeq, %[[LHS_IN]], %[[RHS_IN]] : f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.cmpf
-
-// -----
-
-// CHECK-LABEL: func @float_cmp_ne
-// CHECK-PRIMITIVE-LABEL: func @float_cmp_ne
-func.func @float_cmp_ne(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
- %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction NE>}
- : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
- func.return %0 : tensor<2x2xi1>
-}
-// CHECK: tensor.empty() : tensor<2x2xi1>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: i1):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpf une, %[[LHS_IN]], %[[RHS_IN]] : f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.cmpf
-
-// -----
-
-// CHECK-LABEL: func @float_cmp_totalorder
-// CHECK-PRIMITIVE-LABEL: func @float_cmp_totalorder
-func.func @float_cmp_totalorder(%lhs: tensor<2x2xbf16>,
- %rhs: tensor<2x2xbf16>) -> (tensor<2x2xi1>) {
- %0 = "stablehlo.compare"(%lhs, %rhs) {
- comparison_direction = #stablehlo<comparison_direction LT>,
- compare_type = #stablehlo<comparison_type TOTALORDER>
- } : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xi1>
- func.return %0 : tensor<2x2xi1>
-}
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i16
-// CHECK-DAG: %[[C32767:.*]] = arith.constant 32767 : i16
-// CHECK: tensor.empty() : tensor<2x2xi1>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: bf16, %[[RHS_IN:.*]]: bf16, %{{.*}}: i1):
-// CHECK-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16
-// CHECK-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16
-// CHECK-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16
-// CHECK-NEXT: %[[LHS_SELECT:.*]] = arith.select %[[LHS_CMP]], %[[LHS_SUB]], %[[LHS_INT]] : i16
-// CHECK-NEXT: %[[RHS_INT:.*]] = arith.bitcast %[[RHS_IN]] : bf16 to i16
-// CHECK-NEXT: %[[RHS_CMP:.*]] = arith.cmpi slt, %[[RHS_INT]], %[[C0]] : i16
-// CHECK-NEXT: %[[RHS_SUB:.*]] = arith.subi %[[C32767]], %[[RHS_INT]] : i16
-// CHECK-NEXT: %[[RHS_SELECT:.*]] = arith.select %[[RHS_CMP]], %[[RHS_SUB]], %[[RHS_INT]] : i16
-// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpi slt, %[[LHS_SELECT]], %[[RHS_SELECT]] : i16
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-
-// CHECK-PRIMITIVE-DAG: %[[C0:.*]] = arith.constant 0 : i16
-// CHECK-PRIMITIVE-DAG: %[[C32767:.*]] = arith.constant 32767 : i16
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE-SAME: ins(
-// CHECK-PRIMITIVE-SAME: outs(
-// CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16) {
-// CHECK-PRIMITIVE-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16
-// CHECK-PRIMITIVE-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16
-// CHECK-PRIMITIVE-NEXT: %[[LHS_SUB:.*]] = arith.subi %[[C32767]], %[[LHS_INT]] : i16
-// CHECK-PRIMITIVE-NEXT: %[[LHS_SELECT:.*]] = arith.select %[[LHS_CMP]], %[[LHS_SUB]], %[[LHS_INT]] : i16
-// CHECK-PRIMITIVE-NEXT: %[[RHS_INT:.*]] = arith.bitcast %[[RHS_IN]] : bf16 to i16
-// CHECK-PRIMITIVE-NEXT: %[[RHS_CMP:.*]] = arith.cmpi slt, %[[RHS_INT]], %[[C0]] : i16
-// CHECK-PRIMITIVE-NEXT: %[[RHS_SUB:.*]] = arith.subi %[[C32767]], %[[RHS_INT]] : i16
-// CHECK-PRIMITIVE-NEXT: %[[RHS_SELECT:.*]] = arith.select %[[RHS_CMP]], %[[RHS_SUB]], %[[RHS_INT]] : i16
-// CHECK-PRIMITIVE-NEXT: %[[RESULT:.*]] = arith.cmpi slt, %[[LHS_SELECT]], %[[RHS_SELECT]] : i16
-// CHECK-PRIMITIVE-NEXT: linalg.yield %[[RESULT]] : i1
-
-// -----
-
-// CHECK-LABEL: func @int_cmp
-// CHECK-PRIMITIVE-LABEL: func @int_cmp
-func.func @int_cmp(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
- %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction LT>}
- : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
- func.return %0 : tensor<2x2xi1>
-}
-// CHECK: tensor.empty() : tensor<2x2xi1>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i1):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.cmpi slt, %[[LHS_IN]], %[[RHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.cmpi
-
-// -----
-
-// CHECK-LABEL: func @complex_cmp_eq
-// CHECK-PRIMITIVE-LABEL: func @complex_cmp_eq
-func.func @complex_cmp_eq(%lhs: tensor<2xcomplex<f32>>,
- %rhs: tensor<2xcomplex<f32>>) -> tensor<2xi1> {
- %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction EQ>}
- : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xi1>)
- func.return %0 : tensor<2xi1>
-}
-// CHECK: tensor.empty() : tensor<2xi1>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f32>, %[[RHS_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]: i1):
-// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex<f32>
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: complex.eq
-
-// -----
-
-// CHECK-LABEL: func @complex_cmp_neq
-// CHECK-PRIMITIVE-LABEL: func @complex_cmp_neq
-func.func @complex_cmp_neq(%lhs: tensor<2xcomplex<f64>>,
- %rhs: tensor<2xcomplex<f64>>) -> tensor<2xi1> {
- %0 = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction NE>}
- : (tensor<2xcomplex<f64>>, tensor<2xcomplex<f64>>) -> (tensor<2xi1>)
- func.return %0 : tensor<2xi1>
-}
-// CHECK: tensor.empty() : tensor<2xi1>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: i1):
-// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: complex.neq
-
-// -----
-
-// CHECK-LABEL: func @float_cos
-// CHECK-PRIMITIVE-LABEL: func @float_cos
-func.func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: math.cos
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.cos
- %0 = "stablehlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_cos
-// CHECK-PRIMITIVE-LABEL: func @complex_cos
-func.func @complex_cos(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.cos
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.cos
- %0 = "stablehlo.cosine"(%arg0) : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_sin
-// CHECK-PRIMITIVE-LABEL: func @float_sin
-func.func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: math.sin
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.sin
- %0 = "stablehlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_sin
-// CHECK-PRIMITIVE-LABEL: func @complex_sin
-func.func @complex_sin(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.sin
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: complex.sin
- %0 = "stablehlo.sine"(%arg0) : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func @is_finte
-// CHECK-PRIMITIVE-LABEL: func @is_finte
-func.func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
- %0 = "stablehlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
- func.return %0 : tensor<2x2xi1>
-}
-// CHECK: %[[POS_INF:.+]] = arith.constant 0x7F800000 : f32
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32
-// CHECK-NEXT: %[[ABS_X:.+]] = math.absf %[[OPERAND_IN]] : f32
-// CHECK-NEXT: %[[RESULT:.+]] = arith.cmpf one, %[[ABS_X]], %[[POS_INF]] : f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: math.absf
-// CHECK-PRIMITIVE: arith.cmpf
-
-// -----
-
-// CHECK-LABEL: func @round_nearest_even
-// CHECK-PRIMITIVE-LABEL: func @round_nearest_even
-func.func @round_nearest_even(%val: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
- // CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
- // CHECK: %[[ROUND:.+]] = math.roundeven %[[IN]]
- // CHECK: linalg.yield %[[ROUND]]
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.roundeven
- %0 = "stablehlo.round_nearest_even"(%val) : (tensor<2x2xf32>) -> (tensor<2x2xf32>)
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @round
-// CHECK-PRIMITIVE-LABEL: func @round
-func.func @round(%val: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
- // CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
- // CHECK: %[[ROUND:.+]] = math.round %[[IN]]
- // CHECK: linalg.yield %[[ROUND]]
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: math.round
- %0 = "stablehlo.round_nearest_afz"(%val) : (tensor<2x2xf32>) -> (tensor<2x2xf32>)
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @select
-func.func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- %0 = "stablehlo.select"(%pred, %lhs, %rhs)
- : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
- func.return %0 : tensor<2x2xf32>
-}
-// CHECK: tensor.empty() : tensor<2x2xf32>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
-
-// CHECK-PRIMITIVE-LABEL: func @select
-// CHECK-PRIMITIVE: tensor.empty() : tensor<2x2xf32>
-// CHECK-PRIMITIVE: linalg.map { arith.select }
-// CHECK-PRIMITIVE-SAME: ins(
-// CHECK-PRIMITIVE-SAME: outs(
-
-// -----
-
-// CHECK-DAG: #[[SCALAR_MAP:.*]] = affine_map<(d0, d1) -> ()>
-// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @select_scalar_pred_dyn
-// CHECK-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
-func.func @select_scalar_pred_dyn(%pred : tensor<i1>, %lhs: tensor<2x?xf32>, %rhs: tensor<2x?xf32>) -> tensor<2x?xf32> {
- %0 = "stablehlo.select"(%pred, %lhs, %rhs) {someattr} : (tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>) -> (tensor<2x?xf32>)
- func.return %0 : tensor<2x?xf32>
-}
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1
-// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[LHS]], %[[C1]]
-// CHECK-DAG: %[[DST:.*]] = tensor.empty(%[[DIM]])
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[SCALAR_MAP]], #[[ID_MAP]], #[[ID_MAP]], #[[ID_MAP]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK-SAME: ins(%[[PRED]], %[[LHS]], %[[RHS]] : tensor<i1>, tensor<2x?xf32>, tensor<2x?xf32>)
-// CHECK-SAME: outs(%[[DST]] : tensor<2x?xf32>)
-// CHECK-SAME: {someattr}
-// CHECK: ^bb0(%[[PRED_:.*]]: i1, %[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %{{.*}}: f32):
-// CHECK: %[[RES:.*]] = arith.select %[[PRED_]], %[[LHS_]], %[[RHS_]] : f32
-// CHECK: linalg.yield %[[RES]]
-
-// CHECK-PRIMITIVE-LABEL: func @select_scalar_pred_dyn
-// CHECK-PRIMITIVE-SAME: (%[[PRED:.*]]: tensor<i1>, %[[LHS:.*]]: tensor<2x?xf32>, %[[RHS:.*]]: tensor<2x?xf32>)
-// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1
-// CHECK-PRIMITIVE-DAG: %[[DIM:.*]] = tensor.dim %[[LHS]], %[[C1]]
-// CHECK-PRIMITIVE-DAG: %[[DST:.*]] = tensor.empty(%[[DIM]])
-// CHECK-PRIMITIVE-DAG: %[[PRED_ELEM:.*]] = tensor.extract %[[PRED]]
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>)
-// CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>)
-// CHECK-PRIMITIVE-SAME: {someattr}
-// CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) {
-// CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32
-// CHECK-PRIMITIVE: linalg.yield %[[RES]]
-
-// -----
-
-// CHECK-LABEL: func @select_mixed
-func.func @select_mixed(%pred: tensor<2x?xi1>, %lhs: tensor<?x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<?x2xf32> {
- %0 = "stablehlo.select"(%pred, %lhs, %rhs)
- : (tensor<2x?xi1>, tensor<?x2xf32>, tensor<2x2xf32>) -> (tensor<?x2xf32>)
- func.return %0 : tensor<?x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @bitcast_convert
-func.func @bitcast_convert(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
- %result = "stablehlo.bitcast_convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
- func.return %result : tensor<2x2xf32>
-}
-// CHECK: tensor.empty
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.bitcast
-
-// -----
-
-// CHECK-LABEL: func @bitcast_convert_dynamic
-func.func @bitcast_convert_dynamic(%input: tensor<?x?xi32>) -> tensor<?x?xf32> {
- %result = "stablehlo.bitcast_convert"(%input) : (tensor<?x?xi32>) -> tensor<?x?xf32>
- func.return %result : tensor<?x?xf32>
-}
-// CHECK: tensor.empty
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: f32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.bitcast %[[OPERAND_IN]] : i32 to f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.bitcast
-
-// -----
-
-// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @bitcast_convert_expand
-func.func @bitcast_convert_expand(%input: tensor<6xi32>) -> tensor<6x4xi8> {
- %result = "stablehlo.bitcast_convert"(%input) : (tensor<6xi32>) -> tensor<6x4xi8>
- func.return %result : tensor<6x4xi8>
-}
-
-// CHECK: %[[C8:.*]] = arith.constant 8 : i32
-// CHECK: tensor.empty() : tensor<6x4xi8>
-// CHECK: %[[RESULT:.*]] = linalg.generic {
-// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]],
-// CHECK: iterator_types = ["parallel", "parallel"]}
-// CHECK: ^bb0(%[[IN:.*]]: i32, %[[OUT:.*]]: i8):
-// CHECK: %[[IOTA:.*]] = linalg.index 1 : index
-// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32
-// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i32
-// CHECK: %[[SHIFT:.*]] = arith.shrui %[[IN]], %[[AMT]] : i32
-// CHECK: %[[TRUNC:.*]] = arith.trunci %[[SHIFT]] : i32 to i8
-// CHECK: linalg.yield %[[TRUNC]] : i8
-// CHECK: } -> tensor<6x4xi8>
-// CHECK: return %[[RESULT]] : tensor<6x4xi8>
-
-// -----
-
-// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @bitcast_convert_contract
-func.func @bitcast_convert_contract(%input: tensor<7x4xi8>) -> tensor<7xi32> {
- %result = "stablehlo.bitcast_convert"(%input) : (tensor<7x4xi8>) -> tensor<7xi32>
- func.return %result : tensor<7xi32>
-}
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
-// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<7xi32>
-// CHECK: linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<7xi32>) -> tensor<7xi32>
-// CHECK: %[[RESULT:.*]] = linalg.generic {
-// CHECK: indexing_maps = [#[[MAP0]], #[[MAP1]]],
-// CHECK: iterator_types = ["parallel", "reduction"]}
-// CHECK: ^bb0(%[[IN:.*]]: i8, %[[OUT:.*]]: i32):
-// CHECK: %[[IOTA:.*]] = linalg.index 1 : index
-// CHECK: %[[IOTA_CASTED:.*]] = arith.index_cast %[[IOTA]] : index to i32
-// CHECK: %[[AMT:.*]] = arith.muli %[[IOTA_CASTED]], %[[C8]] : i32
-// CHECK: %[[EXT:.*]] = arith.extui %[[IN]] : i8 to i32
-// CHECK: %[[SHIFT:.*]] = arith.shli %[[EXT]], %[[AMT]] : i32
-// CHECK: %[[OR:.*]] = arith.ori %[[SHIFT]], %[[OUT]] : i32
-// CHECK: linalg.yield %[[OR]] : i32
-// CHECK: } -> tensor<7xi32>
-// CHECK: return %[[RESULT]] : tensor<7xi32>
-
-// -----
-
-// CHECK-LABEL: signed_divide
-func.func @signed_divide(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK-DAG: %[[VAL_7:.*]] = arith.constant -1 : i32
- // CHECK-DAG: %[[VAL_8:.*]] = arith.constant -2147483648 : i32
- // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32
- // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 1 : i32
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32):
- // CHECK: %[[VAL_11:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_9]] : i32
- // CHECK: %[[VAL_13:.*]] = arith.cmpi eq, %[[VAL_4]], %[[VAL_8]] : i32
- // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_7]] : i32
- // CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_13]], %[[VAL_15]] : i1
- // CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_11]], %[[VAL_16]] : i1
- // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_10]], %[[VAL_5]] : i32
- // CHECK: %[[VAL_19:.*]] = arith.divsi %[[VAL_4]], %[[VAL_18]] : i32
- // CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_16]], %[[VAL_8]], %[[VAL_19]] : i32
- // CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_11]], %[[VAL_7]], %[[VAL_20]] : i32
- // CHECK: linalg.yield %[[VAL_21]] : i32
- %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: unsigned_divide
-func.func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
- // CHECK-DAG: %[[VAL_9:.*]] = arith.constant -1 : i32
- // CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : i32
- // CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1 : i32
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32):
- // CHECK: %[[VAL_13:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_11]] : i32
- // CHECK: %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_12]], %[[VAL_7]] : i32
- // CHECK: %[[VAL_15:.*]] = arith.divui %[[VAL_6]], %[[VAL_14]] : i32
- // CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_13]], %[[VAL_9]], %[[VAL_15]] : i32
- // CHECK: linalg.yield %[[VAL_16]] : i32
- %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
- func.return %0 : tensor<2x2xui32>
-}
-
-// -----
-
-// CHECK-LABEL: complex_divide
-func.func @complex_divide(%lhs: tensor<2xcomplex<f32>>,
- %rhs: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: complex.div
- %0 = "stablehlo.divide"(%lhs, %rhs) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
- func.return %0 : tensor<2xcomplex<f32>>
-}
-
-// -----
-
-func.func @shift_left(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- %result = "stablehlo.shift_left"(%lhs, %rhs)
- : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %result : tensor<2x2xi32>
-}
-// CHECK-LABEL: func @shift_left
-// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
-// CHECK-DAG: %[[BITS:.*]] = arith.constant 32
-// CHECK: tensor.empty
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
-// CHECK-DAG: %[[SHIFT:.*]] = arith.shli %[[LHS]], %[[RHS]] : i32
-// CHECK-DAG: %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]]
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[ZERO]]
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// -----
-
-func.func @shift_right_arithmetic(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- %result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs)
- : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %result : tensor<2x2xi32>
-}
-// CHECK-LABEL: func @shift_right_arithmetic
-// CHECK-DAG: %[[BITS:.*]] = arith.constant 32
-// CHECK-DAG: %[[MAX_SHIFT:.*]] = arith.constant 31
-// CHECK: tensor.empty
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
-// CHECK-DAG: %[[SHIFT:.*]] = arith.shrsi %[[LHS]], %[[RHS]] : i32
-// CHECK-DAG: %[[MAX_SHIFTED:.*]] = arith.shrsi %[[LHS]], %[[MAX_SHIFT]] : i32
-// CHECK-DAG: %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]]
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[MAX_SHIFTED]]
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// -----
-
-func.func @shift_right_logical(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- %result = "stablehlo.shift_right_logical"(%lhs, %rhs)
- : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %result : tensor<2x2xi32>
-}
-// CHECK-LABEL: func @shift_right_logical
-// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
-// CHECK-DAG: %[[BITS:.*]] = arith.constant 32
-// CHECK: tensor.empty
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32, %{{.*}}: i32):
-// CHECK-DAG: %[[SHIFT:.*]] = arith.shrui %[[LHS]], %[[RHS]] : i32
-// CHECK-DAG: %[[NOT_SATURATING:.*]] = arith.cmpi ult, %[[RHS]], %[[BITS]]
-// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[NOT_SATURATING]], %[[SHIFT]], %[[ZERO]]
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK: func @einsum_basic
-func.func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
- %0 = "stablehlo.einsum"(%arg0, %arg1) {einsum_config = "ijk,ikm->ijm", someattr}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
- func.return %0 : tensor<3x4x6xf32>
-}
-// CHECK-SAME: (%[[LHS:.*]]: tensor<3x4x5xf32>, %[[RHS:.*]]: tensor<3x5x6xf32>)
-// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<3x4x6xf32>
-// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INIT]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel"]
-// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : tensor<3x4x5xf32>, tensor<3x5x6xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<3x4x6xf32>)
-// CHECK-SAME: {someattr}
-// CHECK: ^bb0(%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32, %[[OUT_:.*]]: f32):
-// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_]], %[[RHS_]] : f32
-// CHECK: %[[RES:.*]] = arith.addf %[[OUT_]], %[[MUL]] : f32
-// CHECK: linalg.yield %[[RES]]
-
-// -----
-
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @float_pow
-func.func @float_pow(%lhs: tensor<2x2xf32>,
- %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- // CHECK: linalg.generic
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
- // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = math.powf %[[ARG0]], %[[ARG1]]
- // CHECK: linalg.yield %[[RESULT]]
- %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
- tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_pow
-func.func @complex_pow(%lhs: tensor<2x2xcomplex<f32>>,
- %rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: linalg.generic
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: complex<f32>
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: complex<f32>
- // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = complex.pow %[[ARG0]], %[[ARG1]]
- // CHECK: linalg.yield %[[RESULT]]
- %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
- tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
-
-// -----
-
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: func @integer_pow
-func.func @integer_pow(%lhs: tensor<2x2xi32>,
- %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: linalg.generic
- // CHECK: ^{{[a-z0-9_]*}}
- // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
- // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
- // CHECK: %[[FOR_RESULT:[a-zA-Z0-9_]*]]:3 = scf.for {{.*}} to %c6 step %c1
- // CHECK-SAME: iter_args(
- // CHECK-SAME: %[[ITER0:.*]] = %c1
- // CHECK-SAME: %[[ITER1:.*]] = %[[ARG0]],
- // CHECK-SAME: %[[ITER2:.*]] = %[[ARG1]]
- // CHECK-SAME: ) -> (i32, i32, i32) {
- // CHECK: %[[AND:[a-zA-Z0-9_]*]] = arith.andi %[[ITER2]], %c1
- // CHECK: %[[COND:[a-zA-Z0-9_]*]] = arith.cmpi eq, %[[AND]], %c1
- // CHECK: %[[MUL:[a-zA-Z0-9_]*]] = arith.muli %[[ITER0]], %[[ITER1]]
- // CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = arith.select %[[COND]], %[[MUL]], %[[ITER0]]
- // CHECK: %[[BASE:[a-zA-Z0-9_]*]] = arith.muli %[[ITER1]], %[[ITER1]]
- // CHECK: %[[EXP:[a-zA-Z0-9_]*]] = arith.shrui %[[ITER2]], %c1
- // CHECK: scf.yield %[[ACCUM]], %[[BASE]], %[[EXP]]
- // CHECK: %[[RHS_PARITY:.*]] = arith.remsi %[[ARG1]], %c2
- // CHECK: %[[RHS_EVEN:.*]] = arith.cmpi eq, %[[RHS_PARITY]], %c0
- // CHECK: %[[RHS_NEG:.*]] = arith.cmpi slt, %[[ARG1]], %c0
- // CHECK: %[[LHS_ONE:.*]] = arith.cmpi eq, %[[ARG0]], %c1
- // CHECK: %[[LHS_NEG_ONE:.*]] = arith.cmpi eq, %[[ARG0]], %c-1
- // CHECK: %[[VAL5:.*]] = arith.extui %[[LHS_ONE]] : i1 to i32
- // CHECK: %[[VAL6:.*]] = arith.select %[[RHS_EVEN]], %c1{{.*}}, %c-1
- // CHECK: %[[VAL7:.*]] = arith.select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
- // CHECK: %[[RESULT:.*]] = arith.select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]#0
- // CHECK: linalg.yield %[[RESULT]]
- %0 = "stablehlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
- tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @real_real
-// CHECK-SAME: (%[[ARG0:.*]]:
-func.func @real_real(%arg0: tensor<?xf32>) -> tensor<?xf32> {
- %1 = "stablehlo.real"(%arg0) : (tensor<?xf32>) -> (tensor<?xf32>)
- // CHECK: return %[[ARG0]]
- func.return %1 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @imag_real
-func.func @imag_real(%arg0: tensor<?xf32>) -> tensor<?xf32> {
- %1 = "stablehlo.imag"(%arg0) : (tensor<?xf32>) -> (tensor<?xf32>)
- // CHECK: %[[CST:.*]] = arith.constant 0
- // CHECK: linalg.generic
- // CHECK: yield %[[CST]]
- func.return %1 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @minf
-func.func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
- %0 = "stablehlo.minimum"(%lhs, %rhs) {someattr}
- : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
- func.return %0 : tensor<2x2xf32>
-}
-// CHECK: tensor.empty() : tensor<2x2xf32>
-// CHECK: linalg.generic
-// CHECK-SAME: {someattr}
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %{{.*}}: f32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.minimumf %[[LHS_IN]], %[[RHS_IN]] : f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.minimumf
-
-// -----
-
-// CHECK-LABEL: func @maxi
-func.func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
- %0 = "stablehlo.maximum"(%lhs, %rhs)
- : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-// CHECK: tensor.empty() : tensor<2x2xi32>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.maxsi %[[LHS_IN]], %[[RHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.maxsi
-
-// -----
-
-// CHECK-LABEL: func @maxu
-func.func @maxu(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
- %0 = "stablehlo.maximum"(%lhs, %rhs)
- : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32>
- func.return %0 : tensor<2x2xui32>
-}
-// CHECK: tensor.empty() : tensor<2x2xi32>
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %{{.*}}: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.maxui %[[LHS_IN]], %[[RHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.maxui
-
-// -----
-
-// CHECK-LABEL: func @maxi1
-func.func @maxi1(%lhs: tensor<?x?xi1>, %rhs: tensor<?x?xi1>) -> tensor<?x?xi1> {
- %0 = "stablehlo.maximum"(%lhs, %rhs)
- : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
- func.return %0 : tensor<?x?xi1>
-}
-// CHECK: linalg.generic
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i1, %[[RHS_IN:.*]]: i1, %{{.*}}: i1):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.maxui %[[LHS_IN]], %[[RHS_IN]] : i1
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
-
-// CHECK-PRIMITIVE: linalg.map
-// CHECK-PRIMITIVE: arith.maxui
-
-
-// -----
-
-// CHECK-LABEL: @clamp_static
-// CHECK-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32>
-func.func @clamp_static(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>)
- -> tensor<4xf32> {
- // CHECK: %[[INIT:.*]] = tensor.empty
- // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
- // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
- // CHECK: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
- // CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
- // CHECK: linalg.yield %[[MIN]]
- // CHECK: } -> tensor<4xf32>
- // CHECK: return %[[RESULT]] : tensor<4xf32>
- %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<4xf32>,
- tensor<4xf32>) -> tensor<4xf32>
- func.return %0 : tensor<4xf32>
-}
-
-// CHECK-PRIMITIVE-LABEL: @clamp_static
-// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32>
-
-// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty
-// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>)
-// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32)
-// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
-// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
-// CHECK-PRIMITIVE: linalg.yield %[[MIN]]
-// CHECK-PRIMITIVE: return %[[RESULT]] : tensor<4xf32>
-
-// -----
-
-// CHECK-LABEL: @clamp_dynamic
-// CHECK-SAME: %[[LB:.*]]: tensor<?xf32>, %[[X:.*]]: tensor<?xf32>, %[[UB:.*]]: tensor<?xf32>
-func.func @clamp_dynamic(%lb : tensor<?xf32>, %x : tensor<?xf32>, %ub : tensor<?xf32>)
- -> tensor<?xf32> {
- // CHECK: %[[INIT:.*]] = tensor.empty
- // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>)
- // CHECK: ^bb0(%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32, %{{.*}}: f32):
- // CHECK: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
- // CHECK: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
- // CHECK: linalg.yield %[[MIN]]
- // CHECK: } -> tensor<?xf32>
- // CHECK: return %[[RESULT]] : tensor<?xf32>
- %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<?xf32>, tensor<?xf32>,
- tensor<?xf32>) -> tensor<?xf32>
- func.return %0 : tensor<?xf32>
-}
-
-// CHECK-PRIMITIVE-LABEL: @clamp_dynamic
-// CHECK-PRIMITIVE: linalg.map
-
-// -----
-
-func.func @clamp_mixed(%lb : tensor<4xf32>, %x : tensor<?xf32>, %ub : tensor<?xf32>)
- -> tensor<?xf32> {
- %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor<?xf32>,
- tensor<?xf32>) -> tensor<?xf32>
- func.return %0 : tensor<?xf32>
-}
-
-// CHECK-LABEL: @clamp_mixed
-// CHECK: linalg.generic
-
-// CHECK-PRIMITIVE-LABEL: @clamp_mixed
-// CHECK-PRIMITIVE: linalg.map
-
-// -----
-
-func.func @clamp_scalar(%lb : tensor<f32>, %x : tensor<?xf32>, %ub : tensor<f32>)
- -> tensor<?xf32> {
- %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<f32>, tensor<?xf32>,
- tensor<f32>) -> tensor<?xf32>
- func.return %0 : tensor<?xf32>
-}
-
-// CHECK-LABEL: @clamp_scalar
-// CHECK: linalg.generic
-
-// CHECK-PRIMITIVE-LABEL: @clamp_scalar
-// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor<f32>, %[[X:.*]]: tensor<?xf32>, %[[UB:.*]]: tensor<f32>
-
-// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.empty
-// CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]]
-// CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]]
-// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor<?xf32>) outs(%[[INIT]] : tensor<?xf32>)
-// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32)
-// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maximumf %[[SCALAR_LB]], %[[SCALAR_X]] : f32
-// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minimumf %[[MAX]], %[[SCALAR_UB]] : f32
-// CHECK-PRIMITIVE: linalg.yield %[[MIN]]
-// CHECK-PRIMITIVE: return %[[RESULT]]
-
-
-// -----
-
-func.func @clamp_scalar_mixed(%lb : tensor<f32>, %x : tensor<?xf32>, %ub : tensor<?xf32>)
- -> tensor<?xf32> {
- %0 = "stablehlo.clamp"(%lb, %x, %ub) : (tensor<f32>, tensor<?xf32>,
- tensor<?xf32>) -> tensor<?xf32>
- func.return %0 : tensor<?xf32>
-}
-
-// CHECK-LABEL: @clamp_scalar_mixed
-// CHECK: linalg.generic
-
-// CHECK-PRIMITIVE-LABEL: @clamp_scalar_mixed
-// CHECK-PRIMITIVE: linalg.map
-
-// -----
-
-// CHECK-LABEL: func @reduce_precision(
-// CHECK-DAG: %[[C2:.*]] = arith.constant 1048576 : i32
-// CHECK-DAG: %[[C_21:.*]] = arith.constant 20 : i32
-// CHECK-DAG: %[[C3:.*]] = arith.constant 524287 : i32
-// CHECK-DAG: %[[C4:.*]] = arith.constant -1048576 : i32
-// CHECK-DAG: %[[C5:.*]] = arith.constant 2139095040 : i32
-// CHECK-DAG: %[[C6:.*]] = arith.constant 1090519040 : i32
-// CHECK-DAG: %[[C7:.*]] = arith.constant 1040187392 : i32
-// CHECK-DAG: %[[C8:.*]] = arith.constant -2147483648 : i32
-// CHECK-DAG: %[[C9:.*]] = arith.constant 2147483647 : i32
-// CHECK: linalg.generic
-// CHECK: %[[X_AS_INT:.*]] = arith.bitcast %[[IN:.*]] : f32 to i32
-// CHECK: %[[ABS_X:.*]] = arith.andi %[[X_AS_INT]], %[[C9]]
-// CHECK: %[[IS_NAN:.*]] = arith.cmpi ugt, %[[ABS_X]], %[[C5]]
-// CHECK: %[[MASKED:.*]] = arith.andi %[[X_AS_INT]], %[[C2]] : i32
-// CHECK: %[[V0:.*]] = arith.shrui %[[MASKED]], %[[C_21]] : i32
-// CHECK: %[[V1:.*]] = arith.addi %[[V0]], %[[C3]] : i32
-// CHECK: %[[V2:.*]] = arith.addi %[[X_AS_INT]], %[[V1]] : i32
-// CHECK: %[[V3:.*]] = arith.andi %[[V2]], %[[C4]] : i32
-// CHECK: %[[V4:.*]] = arith.andi %[[V3]], %[[C5]] : i32
-// CHECK: %[[V5:.*]] = arith.cmpi ugt, %[[V4]], %[[C6]] : i32
-// CHECK: %[[V6:.*]] = arith.cmpi ule, %[[V4]], %[[C7]] : i32
-// CHECK: %[[V7:.*]] = arith.andi %[[V3]], %[[C8]] : i32
-// CHECK: %[[V8:.*]] = arith.ori %[[V7]], %[[C5]] : i32
-// CHECK: %[[V9:.*]] = arith.select %[[V5]], %[[V8]], %[[V3]] : i32
-// CHECK: %[[V10:.*]] = arith.select %[[V6]], %[[V7]], %[[V9]] : i32
-// CHECK: %[[CONVERTED:.*]] = arith.bitcast %[[V10]] : i32 to f32
-// CHECK: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[IN]], %[[CONVERTED]]
-// CHECK: linalg.yield %[[RESULT]]
-
-// CHECK-PRIMITIVE-LABEL: func @reduce_precision(
-// CHECK-PRIMITIVE: linalg.map
-func.func @reduce_precision(%arg0: tensor<1x2x3x4xf32>)
- -> tensor<1x2x3x4xf32> {
- %0 = "stablehlo.reduce_precision"(%arg0) {exponent_bits=3:i32, mantissa_bits=3:i32} : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
- return %0 : tensor<1x2x3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @integer_not
-// CHECK-SAME: (%[[ARG:.+]]: tensor<2x2xi32>)
-// CHECK-PRIMITIVE-LABEL: func @integer_not
-// CHECK-PRIMITIVE-SAME: (%[[ARG:.+]]: tensor<2x2xi32>)
-func.func @integer_not(%arg: tensor<2x2xi32>) -> tensor<2x2xi32> {
- // CHECK: %[[CST_N1:.+]] = arith.constant -1 : i32
- // CHECK: linalg.generic
- // CHECK: (%[[IN:.+]]: i32, %{{.+}}: i32)
- // CHECK: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32
- // CHECK: linalg.yield %[[V_NOT]] : i32
- // CHECK-PRIMITIVE: %[[CST_N1:.+]] = arith.constant -1 : i32
- // CHECK-PRIMITIVE: linalg.map
- // CHECK-PRIMITIVE: (%[[IN:.+]]: i32)
- // CHECK-PRIMITIVE: %[[V_NOT:.+]] = arith.xori %[[IN]], %[[CST_N1]] : i32
- // CHECK-PRIMITIVE: linalg.yield %[[V_NOT]] : i32
- %0 = "stablehlo.not"(%arg) : (tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func @float_complex
-// CHECK-SAME: (%[[LHS:.+]]: tensor<2x2xf32>, %[[RHS:.+]]: tensor<2x2xf32>)
-// CHECK-PRIMITIVE-LABEL: func @float_complex
-// CHECK-PRIMITIVE-SAME: (%[[LHS:.+]]: tensor<2x2xf32>, %[[RHS:.+]]: tensor<2x2xf32>)
-func.func @float_complex(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>> {
- // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x2xcomplex<f32>>
- // CHECK: linalg.generic
- // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
- // CHECK: (%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %{{.+}}: complex<f32>
- // CHECK: %[[RES:.+]] = complex.create %[[IN0]], %[[IN1]] : complex<f32>
- // CHECK: linalg.yield %[[RES]] : complex<f32>
- // CHECK-PRIMITIVE: %[[INIT:.+]] = tensor.empty() : tensor<2x2xcomplex<f32>>
- // CHECK-PRIMITIVE: linalg.map { complex.create } ins(%[[LHS]], %[[RHS]] : tensor<2x2xf32>, tensor<2x2xf32>)
- // CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor<2x2xcomplex<f32>>)
- %0 = "stablehlo.complex"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
- func.return %0 : tensor<2x2xcomplex<f32>>
-}
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir
deleted file mode 100644
index b56d63bf9c64..000000000000
--- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_random.mlir
+++ /dev/null
@@ -1,693 +0,0 @@
-// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \
-// RUN: --canonicalize | FileCheck %s
-
-// CHECK-LABEL: func @rng_uniform_1d
-func.func @rng_uniform_1d(%min: tensor<f32>, %max: tensor<f32>) -> tensor<10xf32> {
- %shape = arith.constant dense<[10]> : tensor<1xi32>
- %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
- func.return %0 : tensor<10xf32>
-}
-// CHECK-DAG: ^{{.+}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32
-// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32
-// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32
-// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
-// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
-// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32
-// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32
-// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32
-// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32
-// CHECK-DAG: %[[VAL2_CAST:.+]] = arith.uitofp %[[VAL2]] : i32 to f32
-// CHECK-DAG: %[[VAL4:.+]] = arith.mulf %[[VAL2_CAST]], %[[FACT]] : f32
-// CHECK-DAG: %[[VAL5:.+]] = arith.addf %[[VAL4]], %[[MIN]] : f32
-// CHECK-NEXT: linalg.yield %[[VAL5]] : f32
-// CHECK-NEXT: -> tensor<10xf32>
-
-// -----
-
-// CHECK-LABEL: func @rng_uniform_2d
-func.func @rng_uniform_2d(%min: tensor<f32>, %max: tensor<f32>) -> tensor<3x3xf32> {
- %shape = arith.constant dense<[3, 3]> : tensor<2xi32>
- %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<2xi32>) -> tensor<3x3xf32>
- func.return %0 : tensor<3x3xf32>
-}
-// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32
-// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32
-// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32
-// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32
-// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
-// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
-// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
-// CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
-// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32
-// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32
-// CHECK-DAG: %[[VAL3:.+]] = arith.addi %[[IDX1_CAST]], %[[VAL2]] : i32
-// CHECK-DAG: %[[VAL4:.+]] = arith.muli %[[VAL3]], %[[CST0]] : i32
-// CHECK-DAG: %[[VAL5:.+]] = arith.addi %[[VAL4]], %[[CST1]] : i32
-// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32
-// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32
-// CHECK-DAG: %[[VAL5_CAST:.+]] = arith.uitofp %[[VAL5]] : i32 to f32
-// CHECK-DAG: %[[VAL6:.+]] = arith.mulf %[[VAL5_CAST]], %[[FACT]] : f32
-// CHECK-DAG: %[[VAL7:.+]] = arith.addf %[[VAL6]], %[[MIN]] : f32
-// CHECK-NEXT: linalg.yield %[[VAL7]] : f32
-// CHECK-NEXT: -> tensor<3x3xf32>
-
-// -----
-
-// CHECK-LABEL: func @rng_uniform_3d
-func.func @rng_uniform_3d(%min: tensor<f32>, %max: tensor<f32>) -> tensor<2x2x2xf32> {
- %shape = arith.constant dense<[2, 2, 2]> : tensor<3xi32>
- %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<3xi32>) -> tensor<2x2x2xf32>
- func.return %0 : tensor<2x2x2xf32>
-}
-// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32
1103515245 : i32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32 -// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index -// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 -// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index -// CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32 -// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index -// CHECK-DAG: %[[IDX2_CAST:.+]] = arith.index_cast %[[IDX2]] : index to i32 -// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32 -// CHECK-DAG: %[[VAL3:.+]] = arith.addi %[[IDX1_CAST]], %[[VAL2]] : i32 -// CHECK-DAG: %[[VAL4:.+]] = arith.muli %[[VAL3]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL5:.+]] = arith.addi %[[VAL4]], %[[CST1]] : i32 -// CHECK-DAG: %[[VAL6:.+]] = arith.addi %[[IDX2_CAST]], %[[VAL5]] : i32 -// CHECK-DAG: %[[VAL7:.+]] = arith.muli %[[VAL6]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL8:.+]] = arith.addi %[[VAL7]], %[[CST1]] : i32 -// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32 -// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32 -// CHECK-DAG: %[[VAL8_CAST:.+]] = arith.uitofp %[[VAL8]] : i32 to f32 -// CHECK-DAG: %[[VAL6:.+]] = arith.mulf %[[VAL8_CAST]], %[[FACT]] : f32 -// CHECK-DAG: %[[VAL7:.+]] = arith.addf %[[VAL6]], %[[MIN]] : f32 -// CHECK-NEXT: linalg.yield %[[VAL7]] : f32 -// CHECK-NEXT: -> tensor<2x2x2xf32> - -// ----- - -// CHECK-LABEL: func @rng_uniform_dynamic_1d -func.func @rng_uniform_dynamic_1d(%min: tensor, %max: tensor, %shape: tensor<1xi32>) -> tensor { - %0 = "stablehlo.rng"(%min, %max, %shape) {rng_distribution = #stablehlo} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %0 : tensor -} -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[EX:.+]] = tensor.extract %{{.*}}[%[[C0]]] -// CHECK-DAG: %[[IND:.+]] = arith.index_cast %[[EX]] : i32 to index -// CHECK-DAG: %{{.+}} = tensor.empty(%[[IND]]) : tensor -// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f32, %[[MAX:.+]]: f32, %[[OUT:.+]]: f32 -// CHECK-DAG: %[[CST0:.+]] = arith.constant 1103515245 : i32 -// CHECK-DAG: %[[CST1:.+]] = arith.constant 12345 : i32 -// CHECK-DAG: %[[CST2:.+]] = arith.constant 2.32830644E-10 : f32 -// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index -// CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 -// CHECK-DAG: %[[VAL1:.+]] = arith.muli %[[IDX0_CAST]], %[[CST0]] : i32 -// CHECK-DAG: %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[CST1]] : i32 -// CHECK-DAG: %[[DIFF:.+]] = arith.subf %[[MAX]], %[[MIN]] : f32 -// CHECK-DAG: %[[FACT:.+]] = arith.mulf %[[DIFF]], %[[CST2]] : f32 -// CHECK-DAG: %[[VAL2_CAST:.+]] = arith.uitofp %[[VAL2]] : i32 to f32 -// CHECK-DAG: %[[VAL4:.+]] = arith.mulf %[[VAL2_CAST]], %[[FACT]] : f32 -// CHECK-DAG: %[[VAL5:.+]] = arith.addf %[[VAL4]], %[[MIN]] : f32 -// CHECK-NEXT: linalg.yield %[[VAL5]] : f32 -// CHECK-NEXT: -> tensor - -// ----- - -// CHECK-LABEL: func.func @three_fry_i64( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64> -func.func @three_fry_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) { - %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) - return %output_state, %output : tensor<2xi64>, tensor<8xi64> -} -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5 : i32 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 4 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = 
arith.constant 2 : i32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : i32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 24 : i32 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 16 : i32 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : i32 -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 29 : i32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : i32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 6 : i32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 26 : i32 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 17 : i32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 15 : i32 -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 19 : i32 -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 13 : i32 -// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 466688986 : i32 -// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 8 : i64 -// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 32 : i64 -// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 1 : index - -// CHECK-DAG: %[[VAL_21:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[VAL_20]]] : tensor<2xi64> -// CHECK-DAG: %[[VAL_22:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[VAL_19]]] : tensor<2xi64> -// CHECK-DAG: %[[VAL_23:.*]] = arith.trunci %[[VAL_22]] : i64 to i32 -// CHECK-DAG: %[[VAL_24:.*]] = arith.shrui %[[VAL_22]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_25:.*]] = arith.trunci %[[VAL_24]] : i64 to i32 -// CHECK-DAG: %[[VAL_26:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : i64 -// CHECK-DAG: %[[VAL_27:.*]] = tensor.empty() : tensor<8xi64> -// CHECK-DAG: %[[VAL_28:.*]] = arith.xori %[[VAL_23]], %[[VAL_16]] : i32 -// CHECK-DAG: %[[VAL_29:.*]] = arith.xori %[[VAL_28]], %[[VAL_25]] : i32 - -// CHECK: %[[GENERIC:.*]] = linalg.generic -// CHECK-SAME: {indexing_maps = [#map], iterator_types = ["parallel"]} -// CHECK-SAME: outs(%[[VAL_27]] : tensor<8xi64>) - -// CHECK: ^bb0(%[[VAL_31:.*]]: i64): - -// CHECK-DAG: %[[VAL_32:.*]] = linalg.index 0 : index -// CHECK-DAG: %[[VAL_33:.*]] = arith.index_cast %[[VAL_32]] : index to i64 -// CHECK-DAG: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_21]] : i64 -// CHECK-DAG: %[[VAL_35:.*]] = arith.trunci %[[VAL_34]] : i64 to i32 -// CHECK-DAG: %[[VAL_36:.*]] = arith.shrui %[[VAL_34]], %[[VAL_18]] : i64 -// CHECK-DAG: %[[VAL_37:.*]] = arith.trunci %[[VAL_36]] : i64 to i32 - -// CHECK-DAG: %[[VAL_38:.*]] = arith.addi %[[VAL_35]], %[[VAL_23]] : i32 -// CHECK-DAG: %[[VAL_39:.*]] = arith.addi %[[VAL_37]], %[[VAL_25]] : i32 - -// CHECK-DAG: %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : i32 -// CHECK-DAG: %[[VAL_41:.*]] = arith.shli %[[VAL_39]], %[[VAL_15]] : i32 -// CHECK-DAG: %[[VAL_42:.*]] = arith.shrui %[[VAL_39]], %[[VAL_14]] : i32 -// CHECK-DAG: %[[VAL_43:.*]] = arith.ori %[[VAL_41]], %[[VAL_42]] : i32 -// CHECK-DAG: %[[VAL_44:.*]] = arith.xori %[[VAL_40]], %[[VAL_43]] : i32 - -// CHECK-DAG: %[[VAL_45:.*]] = arith.addi %[[VAL_40]], %[[VAL_44]] : i32 -// CHECK-DAG: %[[VAL_46:.*]] = arith.shli %[[VAL_44]], %[[VAL_13]] : i32 -// CHECK-DAG: %[[VAL_47:.*]] = arith.shrui %[[VAL_44]], %[[VAL_12]] : i32 -// CHECK-DAG: %[[VAL_48:.*]] = arith.ori %[[VAL_46]], %[[VAL_47]] : i32 -// CHECK-DAG: %[[VAL_49:.*]] = arith.xori %[[VAL_45]], %[[VAL_48]] : i32 - -// CHECK-DAG: %[[VAL_50:.*]] = arith.addi %[[VAL_45]], %[[VAL_49]] : i32 -// CHECK-DAG: %[[VAL_51:.*]] = arith.shli %[[VAL_49]], %[[VAL_11]] : i32 -// CHECK-DAG: %[[VAL_52:.*]] = arith.shrui %[[VAL_49]], %[[VAL_10]] : i32 -// CHECK-DAG: %[[VAL_53:.*]] = arith.ori %[[VAL_51]], %[[VAL_52]] : i32 -// CHECK-DAG: %[[VAL_54:.*]] = arith.xori %[[VAL_50]], %[[VAL_53]] : i32 - 
-// CHECK-DAG: %[[VAL_55:.*]] = arith.addi %[[VAL_50]], %[[VAL_54]] : i32
-// CHECK-DAG: %[[VAL_56:.*]] = arith.shli %[[VAL_54]], %[[VAL_10]] : i32
-// CHECK-DAG: %[[VAL_57:.*]] = arith.shrui %[[VAL_54]], %[[VAL_11]] : i32
-// CHECK-DAG: %[[VAL_58:.*]] = arith.ori %[[VAL_56]], %[[VAL_57]] : i32
-// CHECK-DAG: %[[VAL_59:.*]] = arith.xori %[[VAL_55]], %[[VAL_58]] : i32
-
-// CHECK-DAG: %[[VAL_60:.*]] = arith.addi %[[VAL_55]], %[[VAL_25]] : i32
-// CHECK-DAG: %[[VAL_61:.*]] = arith.addi %[[VAL_59]], %[[VAL_29]] : i32
-// CHECK-DAG: %[[VAL_62:.*]] = arith.addi %[[VAL_61]], %[[VAL_9]] : i32
-// CHECK-DAG: %[[VAL_63:.*]] = arith.addi %[[VAL_60]], %[[VAL_62]] : i32
-// CHECK-DAG: %[[VAL_64:.*]] = arith.shli %[[VAL_62]], %[[VAL_12]] : i32
-// CHECK-DAG: %[[VAL_65:.*]] = arith.shrui %[[VAL_62]], %[[VAL_13]] : i32
-// CHECK: %[[VAL_66:.*]] = arith.ori %[[VAL_64]], %[[VAL_65]] : i32
-
-// CHECK: linalg.yield %[[YIELDED:.*]] : i64
-
-// Set the updated state.
-// CHECK: %[[VAL_159:.*]] = tensor.insert %[[VAL_26]] into %[[ARG0]]{{\[}}%[[VAL_20]]] : tensor<2xi64>
-
-// CHECK: return %[[VAL_159]], %[[GENERIC:.*]] : tensor<2xi64>, tensor<8xi64>
-
-// -----
-
-// CHECK-LABEL: func.func @three_fry_i32
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-func.func @three_fry_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi32>
-}
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi32>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi32>
-// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi32>, tensor<4xi32>)
-
-// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32>
-
-// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32>
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi32>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi32>)
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
-// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi32> into tensor<8xi32>
-// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-
-// CHECK: return %[[INSERTED]], %[[COLLAPSE]]
-
-
-// -----
-
-// CHECK-LABEL: func.func @three_fry_odd_i32
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-func.func @three_fry_odd_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>)
- return %output_state, %output : tensor<2xi64>, tensor<7x11xi32>
-}
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C42:.+]] = arith.constant 42 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C42]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<42xi32>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<42xi32>
-// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<42xi32>, tensor<42xi32>)
-
-// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0
-// CHECK-SAME{literal}: [[0, 1]] output_shape [7, 6, 1] : tensor<4xi32> into tensor<7x6x1xi32>
-
-// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1
-// CHECK-SAME{literal}: [[0, 1]] output_shape [7, 6, 1] : tensor<4xi32> into tensor<7x6x1xi32>
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<7x6x2xi32>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<7x6x2xi32>)
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %10
-// CHECK-SAME{literal}: [[0], [1, 2]] : tensor<7x6x2xi32> into tensor<7x12xi32>
-
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]][0, 0] [7, 11] [1, 1]
-// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: return %[[INSERTED]], %[[SLICE]] : tensor<2xi64>, tensor<7x11xi32>
-
-// -----
-
-// CHECK-LABEL: func.func @three_fry_i16
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-func.func @three_fry_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi16>
-}
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi16>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi16>
-// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi16>, tensor<4xi16>)
-
-// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi16> into tensor<4x1xi16>
-
-// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi16> into tensor<4x1xi16>
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi16>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi16>)
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
-// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi16> into tensor<8xi16>
-// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-
-// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi16>
-
-// -----
-
-// CHECK-LABEL: func.func @three_fry_i8
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-func.func @three_fry_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi8>
-}
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C4]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<4xi8>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<4xi8>
-// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]] : tensor<4xi8>, tensor<4xi8>)
-
-// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi8> into tensor<4x1xi8>
-
-// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi8> into tensor<4x1xi8>
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xi8>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x2xi8>)
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
-// CHECK-SAME{literal}: [[0, 1]] : tensor<4x2xi8> into tensor<8xi8>
-// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-
-// CHECK: return %[[INSERTED]], %[[COLLAPSE]] : tensor<2xi64>, tensor<8xi8>
-
-// -----
-
-// CHECK-LABEL: func.func @philox_i64
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi64>
-
-func.func @philox_i64(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi64>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi64>
-}
-
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant -1767562579 : i32
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1879881855 : i32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant -616729560 : i32
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -239350328 : i32
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 534103459 : i32
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1401181199 : i32
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1684936478 : i32
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -1253254570 : i32
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant -1459197799 : i32
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 387276957 : i32
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant -308364780 : i32
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 2027808484 : i32
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 842468239 : i32
-// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -626627285 : i32
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1993301258 : i32
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 1013904242 : i32
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 3449720151 : i64
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 3528531795 : i64
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_20:.*]] = arith.constant -1150833019 : i32
-// CHECK-DAG: %[[VAL_21:.*]] = arith.constant -1640531527 : i32
-// CHECK-DAG: %[[VAL_22:.*]] = arith.constant 4 : i64
-// CHECK-DAG: %[[VAL_23:.*]] = arith.constant 32 : i64
-// CHECK-DAG: %[[VAL_24:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_25:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_19]]] : tensor<2xi64>
-// CHECK-DAG: %[[VAL_26:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_24]]] : tensor<2xi64>
-// CHECK-DAG: %[[VAL_27:.*]] = arith.trunci %[[VAL_26]] : i64 to i32
-// CHECK-DAG: %[[VAL_28:.*]] = arith.shrui %[[VAL_26]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_29:.*]] = arith.trunci %[[VAL_28]] : i64 to i32
-// CHECK-DAG: %[[VAL_30:.*]] = arith.addi %[[VAL_25]], %[[VAL_22]] : i64
-// CHECK-DAG: %[[VAL_31:.*]] = tensor.empty() : tensor<4xi64>
-// CHECK-DAG: %[[VAL_32:.*]] = tensor.empty() : tensor<4xi64>
-// CHECK-DAG: %[[VAL_33:.*]]:2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} outs(%[[VAL_31]], %[[VAL_32]] : tensor<4xi64>, tensor<4xi64>) {
-// CHECK-DAG: ^bb0(%[[VAL_34:.*]]: i64, %[[VAL_35:.*]]: i64):
-// CHECK-DAG: %[[VAL_36:.*]] = linalg.index 0 : index
-// CHECK-DAG: %[[VAL_37:.*]] = arith.index_cast %[[VAL_36]] : index to i64
-// CHECK-DAG: %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_25]] : i64
-// CHECK-DAG: %[[VAL_39:.*]] = arith.trunci %[[VAL_38]] : i64 to i32
-// CHECK-DAG: %[[VAL_40:.*]] = arith.shrui %[[VAL_38]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_41:.*]] = arith.trunci %[[VAL_40]] : i64 to i32
-
-// CHECK-DAG: %[[VAL_42:.*]] = arith.extui %[[VAL_39]] : i32 to i64
-// CHECK-DAG: %[[VAL_43:.*]] = arith.muli %[[VAL_42]], %[[VAL_17]] : i64
-// CHECK-DAG: %[[VAL_44:.*]] = arith.shrui %[[VAL_43]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_45:.*]] = arith.trunci %[[VAL_44]] : i64 to i32
-// CHECK-DAG: %[[VAL_46:.*]] = arith.trunci %[[VAL_43]] : i64 to i32
-// CHECK-DAG: %[[VAL_47:.*]] = arith.extui %[[VAL_27]] : i32 to i64
-// CHECK-DAG: %[[VAL_48:.*]] = arith.muli %[[VAL_47]], %[[VAL_18]] : i64
-// CHECK-DAG: %[[VAL_49:.*]] = arith.shrui %[[VAL_48]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_50:.*]] = arith.trunci %[[VAL_49]] : i64 to i32
-// CHECK-DAG: %[[VAL_51:.*]] = arith.trunci %[[VAL_48]] : i64 to i32
-// CHECK-DAG: %[[VAL_52:.*]] = arith.xori %[[VAL_50]], %[[VAL_41]] : i32
-// CHECK-DAG: %[[VAL_53:.*]] = arith.xori %[[VAL_52]], %[[VAL_27]] : i32
-
-// CHECK-DAG: %[[VAL_54:.*]] = arith.addi %[[VAL_27]], %[[VAL_21]] : i32
-// CHECK-DAG: %[[VAL_55:.*]] = arith.addi %[[VAL_29]], %[[VAL_20]] : i32
-// CHECK-DAG: %[[VAL_56:.*]] = arith.extui %[[VAL_53]] : i32 to i64
-// CHECK-DAG: %[[VAL_57:.*]] = arith.muli %[[VAL_56]], %[[VAL_17]] : i64
-// CHECK-DAG: %[[VAL_58:.*]] = arith.shrui %[[VAL_57]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_59:.*]] = arith.trunci %[[VAL_58]] : i64 to i32
-// CHECK-DAG: %[[VAL_60:.*]] = arith.trunci %[[VAL_57]] : i64 to i32
-// CHECK-DAG: %[[VAL_61:.*]] = arith.extui %[[VAL_45]] : i32 to i64
-// CHECK-DAG: %[[VAL_62:.*]] = arith.muli %[[VAL_61]], %[[VAL_18]] : i64
-// CHECK-DAG: %[[VAL_63:.*]] = arith.shrui %[[VAL_62]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_64:.*]] = arith.trunci %[[VAL_63]] : i64 to i32
-// CHECK-DAG: %[[VAL_65:.*]] = arith.trunci %[[VAL_62]] : i64 to i32
-// CHECK-DAG: %[[VAL_66:.*]] = arith.xori %[[VAL_64]], %[[VAL_51]] : i32
-// CHECK-DAG: %[[VAL_67:.*]] = arith.xori %[[VAL_66]], %[[VAL_54]] : i32
-// CHECK-DAG: %[[VAL_68:.*]] = arith.xori %[[VAL_59]], %[[VAL_46]] : i32
-// CHECK-DAG: %[[VAL_69:.*]] = arith.xori %[[VAL_68]], %[[VAL_55]] : i32
-
-// CHECK-DAG: %[[VAL_70:.*]] = arith.addi %[[VAL_27]], %[[VAL_16]] : i32
-// CHECK-DAG: %[[VAL_71:.*]] = arith.addi %[[VAL_29]], %[[VAL_15]] : i32
-// CHECK-DAG: %[[VAL_72:.*]] = arith.extui %[[VAL_67]] : i32 to i64
-// CHECK-DAG: %[[VAL_73:.*]] = arith.muli %[[VAL_72]], %[[VAL_17]] : i64
-// CHECK-DAG: %[[VAL_74:.*]] = arith.shrui %[[VAL_73]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_75:.*]] = arith.trunci %[[VAL_74]] : i64 to i32
-// CHECK-DAG: %[[VAL_76:.*]] = arith.trunci %[[VAL_73]] : i64 to i32
-// CHECK-DAG: %[[VAL_77:.*]] = arith.extui %[[VAL_69]] : i32 to i64
-// CHECK-DAG: %[[VAL_78:.*]] = arith.muli %[[VAL_77]], %[[VAL_18]] : i64
-// CHECK-DAG: %[[VAL_79:.*]] = arith.shrui %[[VAL_78]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_80:.*]] = arith.trunci %[[VAL_79]] : i64 to i32
-// CHECK-DAG: %[[VAL_81:.*]] = arith.trunci %[[VAL_78]] : i64 to i32
-// CHECK: %[[VAL_82:.*]] = arith.xori %[[VAL_80]], %[[VAL_65]] : i32
-// CHECK-DAG: %[[VAL_83:.*]] = arith.xori %[[VAL_82]], %[[VAL_70]] : i32
-// CHECK-DAG: %[[VAL_84:.*]] = arith.xori %[[VAL_75]], %[[VAL_60]] : i32
-// CHECK-DAG: %[[VAL_85:.*]] = arith.xori %[[VAL_84]], %[[VAL_71]] : i32
-
-// CHECK-DAG: %[[VAL_86:.*]] = arith.addi %[[VAL_27]], %[[VAL_14]] : i32
-// CHECK-DAG: %[[VAL_87:.*]] = arith.addi %[[VAL_29]], %[[VAL_13]] : i32
-// CHECK-DAG: %[[VAL_88:.*]] = arith.extui %[[VAL_83]] : i32 to i64
-// CHECK-DAG: %[[VAL_89:.*]] = arith.muli %[[VAL_88]], %[[VAL_17]] : i64
-// CHECK-DAG: %[[VAL_90:.*]] = arith.shrui %[[VAL_89]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_91:.*]] = arith.trunci %[[VAL_90]] : i64 to i32
-// CHECK-DAG: %[[VAL_92:.*]] = arith.trunci %[[VAL_89]] : i64 to i32
-// CHECK-DAG: %[[VAL_93:.*]] = arith.extui %[[VAL_85]] : i32 to i64
-// CHECK-DAG: %[[VAL_94:.*]] = arith.muli %[[VAL_93]], %[[VAL_18]] : i64
-// CHECK-DAG: %[[VAL_95:.*]] = arith.shrui %[[VAL_94]], %[[VAL_23]] : i64
-// CHECK-DAG: %[[VAL_96:.*]] = arith.trunci %[[VAL_95]] : i64 to i32
-// CHECK-DAG: %[[VAL_97:.*]] = arith.trunci %[[VAL_94]] : i64 to i32
-// CHECK-DAG: %[[VAL_98:.*]] = arith.xori %[[VAL_96]], %[[VAL_81]] : i32
-// CHECK-DAG: %[[VAL_99:.*]] = arith.xori %[[VAL_98]], %[[VAL_86]] : i32
-// CHECK-DAG: %[[VAL_100:.*]] = arith.xori %[[VAL_91]], %[[VAL_76]] : i32
-// CHECK-DAG: %[[VAL_101:.*]] = arith.xori %[[VAL_100]], %[[VAL_87]] : i32
-
-// CHECK: linalg.yield %[[YIELDED_1:.*]], %[[YIELDED_2:.*]] : i64, i64
-// CHECK-DAG: %[[VAL_206:.*]] = tensor.expand_shape %[[VAL_207:.*]]#0 {{\[\[}}0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64>
-// CHECK-DAG: %[[VAL_208:.*]] = tensor.expand_shape %[[VAL_207]]#1 {{\[\[}}0, 1]] output_shape [4, 1] : tensor<4xi64> into tensor<4x1xi64>
-// CHECK-DAG: %[[VAL_209:.*]] = tensor.empty() : tensor<4x2xi64>
-// CHECK-DAG: %[[VAL_213:.*]] = tensor.insert %[[VAL_30]] into %[[VAL_0]]{{\[}}%[[VAL_19]]] : tensor<2xi64>
-
-// CHECK: return %[[VAL_213]], %[[GENERIC:.*]] : tensor<2xi64>, tensor<8xi64>
-
-
-
-// -----
-
-// CHECK-LABEL: func.func @philox_i32
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-func.func @philox_i32(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi32>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi32>
-}
-
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi32>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi32>
-// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi32>
-// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi32>
-// CHECK: %[[GENERIC:.+]]:4 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map, #map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
-
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
-// CHECK-SAME{literal}: [[0, 1]] : tensor<2x4xi32> into tensor<8xi32>
-// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-
-// CHECK: return %[[INSERTED]], %[[COLLAPSE]]
-
-
-// -----
-
-// CHECK-LABEL: func.func @philox_128_i32
-// CHECK-SAME: %[[ARG0:.*]]: tensor<3xi64>
-func.func @philox_128_i32(%arg0: tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<3xi64>) -> (tensor<3xi64>, tensor<8xi32>)
- return %output_state, %output : tensor<3xi64>, tensor<8xi32>
-}
-
-// -----
-
-func.func @philox_i32_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<7x11xi32>)
- return %output_state, %output : tensor<2xi64>, tensor<7x11xi32>
-}
-
-// CHECK-LABEL: func.func @philox_i32_odd
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-
- //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C20]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<20xi32>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<20xi32>
-// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<20xi32>
-// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<20xi32>
-// CHECK: %[[GENERIC:.+]]:4 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map, #map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<20xi32>, tensor<20xi32>, tensor<20xi32>, tensor<20xi32>)
-
-
-// CHECK: %expanded = tensor.expand_shape %[[GENERIC]]#0
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32>
-
-// CHECK: %expanded_1 = tensor.expand_shape %[[GENERIC]]#1
-// CHECK-SAME{literal}: [[0, 1]] output_shape [4, 1] : tensor<4xi32> into tensor<4x1xi32>
-
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<20x4xi32>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<20x4xi32>)
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
-
-
-// CHECK: %[[VAL_213:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1]] output_shape [80, 1] : tensor<80xi32> into tensor<80x1xi32>
-// CHECK: %[[VAL_214:.*]] = tensor.extract_slice %[[VAL_213]][0, 0] [77, 1] [1, 1] : tensor<80x1xi32> to tensor<77x1xi32>
-// CHECK: %[[VAL_215:.*]] = tensor.collapse_shape %[[VAL_214]] {{\[\[}}0, 1]] : tensor<77x1xi32> into tensor<77xi32>
-// CHECK: %[[VAL_216:.*]] = tensor.expand_shape %[[VAL_215]] {{\[\[}}0, 1]] output_shape [7, 11] : tensor<77xi32> into tensor<7x11xi32>
-// CHECK: %[[VAL_217:.*]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]]{{\[}}%[[C1]]] : tensor<2xi64>
-// CHECK: return %[[VAL_217]], %[[VAL_216]] : tensor<2xi64>, tensor<7x11xi32>
-
-
-// -----
-
-
-func.func @philox_i64_odd(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<3x5xi64>)
- return %output_state, %output : tensor<2xi64>, tensor<3x5xi64>
-}
-
-// CHECK-LABEL: func.func @philox_i64_odd
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-
- //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C8]] : i64
-
-// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<8xi64>
-// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<8xi64>
-// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST2]], %[[DEST3]] : tensor<8xi64>, tensor<8xi64>)
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x2xi64>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x2xi64>)
-
-// CHECK-DAG: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]] {{\[\[}}0, 1]] : tensor<8x2xi64> into tensor<16xi64>
-
-
-// CHECK-DAG: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1]] output_shape [16, 1] : tensor<16xi64> into tensor<16x1xi64>
-// CHECK-DAG: %[[SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0] [15, 1] [1, 1] : tensor<16x1xi64> to tensor<15x1xi64>
-// CHECK-DAG: %[[EXPAND_2:.*]] = tensor.collapse_shape %[[SLICE]] {{\[\[}}0, 1]] : tensor<15x1xi64> into tensor<15xi64>
-// CHECK-DAG: %[[RESHAPE:.*]] = tensor.expand_shape %[[EXPAND_2]] {{\[\[}}0, 1]] output_shape [3, 5] : tensor<15xi64> into tensor<3x5xi64>
-// CHECK-DAG: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: return %[[INSERTED]], %[[RESHAPE]]
-
-// -----
-
-func.func @philox_i16(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi16>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi16>
-}
-
-// CHECK-LABEL: func.func @philox_i16
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-
- //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi16>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi16>
-// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi16>
-// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi16>
-// CHECK: %[[GENERIC:.+]]:4 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map, #map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi16>, tensor<2xi16>, tensor<2xi16>, tensor<2xi16>)
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi16>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4xi16>)
-
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[CONCAT]]
-// CHECK-SAME{literal}: [[0, 1]] : tensor<2x4xi16> into tensor<8xi16>
-// CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEWSTATE]] into %[[ARG0]][%[[C1]]] : tensor<2xi64>
-
-// CHECK: return %[[INSERTED]], %[[COLLAPSE]]
-
-// -----
-
-func.func @philox_i8(%arg0: tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>) {
- %output_state, %output = "stablehlo.rng_bit_generator"(%arg0) {rng_algorithm = #stablehlo} : (tensor<2xi64>) -> (tensor<2xi64>, tensor<8xi8>)
- return %output_state, %output : tensor<2xi64>, tensor<8xi8>
-}
-
-// CHECK-LABEL: func.func @philox_i8
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi64>
-
- //CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
- //CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64
-
-// Check we update state correctly:
-// CHECK: %[[STATE:.+]] = tensor.extract %[[ARG0]][%[[C1]]] : tensor<2xi64>
-// CHECK: %[[NEWSTATE:.+]] = arith.addi %[[STATE]], %[[C2]] : i64
-
-// CHECK: %[[DEST0:.+]] = tensor.empty() : tensor<2xi8>
-// CHECK: %[[DEST1:.+]] = tensor.empty() : tensor<2xi8>
-// CHECK: %[[DEST2:.+]] = tensor.empty() : tensor<2xi8>
-// CHECK: %[[DEST3:.+]] = tensor.empty() : tensor<2xi8>
-// CHECK: %[[GENERIC:.+]]:4 = linalg.generic
-// CHECK-SAME: indexing_maps = [#map, #map, #map, #map]
-// CHECK-SAME: iterator_types = ["parallel"]}
-// CHECK-SAME: outs(%[[DEST0]], %[[DEST1]], %[[DEST2]], %[[DEST3]] : tensor<2xi8>, tensor<2xi8>, tensor<2xi8>, tensor<2xi8>)
-
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi8>
-// CHECK: %[[CONCAT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4xi8>)
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir
deleted file mode 100644
index 313f645bcb1d..000000000000
--- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_reduce.mlir
+++ /dev/null
@@ -1,794 +0,0 @@
-// RUN: iree-opt %s --iree-stablehlo-to-linalg --split-input-file \
-// RUN: --canonicalize | FileCheck %s
-
-// RUN: iree-opt %s --iree-stablehlo-to-linalg="enable-primitive-ops=true" \
-// RUN: --split-input-file --canonicalize | \
-// RUN: FileCheck %s --check-prefix=CHECK-PRIMITIVE
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: @reduce_add
-// CHECK-PRIMITIVE-LABEL: @reduce_add
-func.func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> {
- %0 = "stablehlo.reduce"(%arg0, %arg1) ({
- ^bb0(%arg3: tensor, %arg4 : tensor):
- %1 = stablehlo.add %arg3, %arg4 : tensor
- "stablehlo.return"(%1) : (tensor) -> ()
- }) {dimensions = array, someattr} : (tensor<5x4xi32>, tensor) -> tensor<5xi32>
- func.return %0 : tensor<5xi32>
-}
-// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
-// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
-// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
-// CHECK-SAME: {someattr}
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
-// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
-// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK-PRIMITIVE: linalg.reduce { arith.addi {overflowFlags = #arith.overflow} }
-// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>)
-// CHECK-PRIMITIVE-SAME: dimensions = [1] {someattr}
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: @reduce_dim0
-// CHECK-PRIMITIVE-LABEL: @reduce_dim0
-func.func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<4xi32> {
- %0 = "stablehlo.reduce"(%arg0, %arg1) ({
- ^bb0(%arg3: tensor, %arg4 : tensor):
- %1 = stablehlo.maximum %arg3, %arg4 : tensor
- "stablehlo.return"(%1) : (tensor) -> ()
- }) {dimensions = array} : (tensor<5x4xi32>, tensor) -> tensor<4xi32>
- func.return %0 : tensor<4xi32>
-}
-// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
-// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
-// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.maxsi %[[RHS_IN]], %[[LHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
-// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
-// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK-PRIMITIVE: linalg.reduce { arith.maxsi }
-// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>)
-// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
-// CHECK-PRIMITIVE-SAME: dimensions = [0]
-
-// -----
-
-func.func @reduce_dynamic_output(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor {
- %0 = "stablehlo.reduce"(%arg0, %arg1) ({
- ^bb0(%arg3: tensor, %arg4 : tensor):
- %1 = stablehlo.maximum %arg3, %arg4 : tensor
- "stablehlo.return"(%1) : (tensor) -> ()
- }) {dimensions = array} : (tensor<5x4xi32>, tensor) -> tensor
- func.return %0 : tensor
-}
-
-// Regression test: just check that this lowers successfully.
-// CHECK-LABEL: @reduce_dynamic_output
-// CHECK: linalg.generic
-
-// CHECK-PRIMITIVE-LABEL: @reduce_dynamic_output
-// CHECK-PRIMITIVE: linalg.reduce
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: @reduce_init_const
-func.func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> {
- %cst = arith.constant dense<0xFF800000> : tensor
- %0 = "stablehlo.reduce"(%arg0, %cst) ({
- ^bb0(%arg1: tensor, %arg2: tensor):
- %1 = stablehlo.add %arg1, %arg2 : tensor
- "stablehlo.return"(%1) : (tensor) -> ()
- }) {dimensions = array} : (tensor<1x10xf32>, tensor) -> tensor<1xf32>
- func.return %0 : tensor<1xf32>
-}
-// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
-// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%{{.*}}tensor<1x10xf32>)
-// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<1xf32>)
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.addf %[[RHS_IN]], %[[LHS_IN]] : f32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
-// CHECK: @reduce_multi_dimensions
-func.func @reduce_multi_dimensions(%arg0: tensor<5x4x3xi32>,
- %arg1: tensor) -> tensor<4xi32> {
- %0 = "stablehlo.reduce"(%arg0, %arg1) ({
- ^bb0(%arg2: tensor, %arg3: tensor):
- %1 = stablehlo.add %arg2, %arg3 : tensor
- "stablehlo.return"(%1) : (tensor) -> ()
- }) {dimensions = array} : (tensor<5x4x3xi32>, tensor) -> tensor<4xi32>
- func.return %0 : tensor<4xi32>
-}
-// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
-// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty()
-// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
-// CHECK-SAME: ins(%{{.*}}tensor<5x4x3xi32>)
-// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>)
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// -----
-
-// CHECK-LABEL: @reduce_lexicographic_min_complex
-// CHECK-PRIMITIVE-LABEL: @reduce_lexicographic_min_complex
-func.func @reduce_lexicographic_min_complex(%arg0: tensor>,
- %arg1: tensor>)
- -> tensor> {
- %0 = stablehlo.reduce(%arg0 init: %arg1)
- across dimensions = [0, 1, 2]
- : (tensor>, tensor>) -> tensor>
- reducer(%arg3: tensor>, %arg4: tensor>) {
- %1 = stablehlo.real %arg3 : (tensor>) -> tensor
- %2 = stablehlo.convert %arg4 : (tensor>) -> tensor
- %3 = "stablehlo.compare"(%1, %2)
- {comparison_direction = #stablehlo}
- : (tensor, tensor) -> tensor
- %4 = stablehlo.imag %arg3 : (tensor>) -> tensor
- %5 = stablehlo.imag %arg4 : (tensor>) -> tensor
- %6 = "stablehlo.compare"(%4, %5)
- {comparison_direction = #stablehlo}
- : (tensor, tensor) -> tensor
- %7 = "stablehlo.compare"(%1, %2)
- {comparison_direction = #stablehlo}
- : (tensor, tensor) -> tensor
- %8 = "stablehlo.select"(%3, %6, %7)
- : (tensor, tensor, tensor) -> tensor
- %9 = "stablehlo.select"(%8, %arg3, %arg4)
- : (tensor, tensor>, tensor>)
- -> tensor>
- "stablehlo.return"(%9) : (tensor>) -> ()
- }
- return %0 : tensor>
-}
-
-// CHECK: linalg.generic
-// CHECK: complex.re
-// CHECK: complex.re
-// CHECK: arith.cmpf
-// CHECK: complex.im
-// CHECK: complex.im
-// CHECK: arith.cmpf
-// CHECK: arith.cmpf
-// CHECK: arith.select
-
-// CHECK-PRIMITIVE: linalg.reduce
-// CHECK-PRIMITIVE: complex.re
-// CHECK-PRIMITIVE: complex.re
-// CHECK-PRIMITIVE: arith.cmpf
-// CHECK-PRIMITIVE: complex.im
-// CHECK-PRIMITIVE: complex.im
-// CHECK-PRIMITIVE: arith.cmpf
-// CHECK-PRIMITIVE: arith.cmpf
-// CHECK-PRIMITIVE: arith.select
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @reduce_dynamic(%[[ARG0:.*]]: tensor
-func.func @reduce_dynamic(%arg0: tensor, %arg1: tensor) -> tensor {
- %0 = "stablehlo.reduce"(%arg0, %arg1) ({
- ^bb0(%arg3: tensor, %arg4 : tensor):
- %1 = stablehlo.add %arg3, %arg4 : tensor
- "stablehlo.return"(%1) : (tensor) -> ()
- }) {dimensions = array} : (tensor, tensor) -> tensor
- func.return %0 : tensor
-}
-// CHECK-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK-DAG: %[[INIT_TENSOR:.*]] = tensor.empty(%[[DIM1]])
-// CHECK-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%{{.*}}tensor)
-// CHECK-SAME: outs(%[[FILL_TENSOR]] : tensor)
-// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
-// CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32
-// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @variadic_reduce
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-// CHECK-PRIMITIVE-LABEL: func @variadic_reduce
-// CHECK-PRIMITIVE-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-PRIMITIVE-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
- %cst0 = stablehlo.constant dense<-2147483648> : tensor
- %cst1 = stablehlo.constant dense<0> : tensor
- %res0, %res1 = "stablehlo.reduce"(%arg0, %arg1, %cst0, %cst1) ({
- ^bb0(%arg2: tensor, %arg3: tensor, %arg15: tensor, %arg16: tensor):
- %669 = "stablehlo.compare"(%arg2, %arg15) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor
- %670 = "stablehlo.select"(%669, %arg2, %arg15) : (tensor, tensor, tensor) -> tensor
- %671 = "stablehlo.compare"(%arg2, %arg15) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor
- %672 = stablehlo.minimum %arg3, %arg16 : tensor
- %673 = "stablehlo.select"(%669, %arg3, %arg16) : (tensor, tensor, tensor) -> tensor
- %674 = "stablehlo.select"(%671, %672, %673) : (tensor, tensor, tensor) -> tensor
- "stablehlo.return"(%670, %674) : (tensor, tensor) -> ()
- }) {dimensions = array} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor, tensor) -> (tensor<2xi32>, tensor<2xi32>)
- func.return %res0, %res1 : tensor<2xi32>, tensor<2xi32>
-}
-// CHECK-DAG: %[[CST0:.*]] = arith.constant -2147483648 : i32
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 0 : i32
-// CHECK: %[[INIT0:.*]] = tensor.empty() : tensor<2xi32>
-// CHECK: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]]
-// CHECK: %[[INIT1:.*]] = tensor.empty() : tensor<2xi32>
-// CHECK: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]]
-// CHECK: %[[RES:.+]]:2 = linalg.generic {
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>)
-// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>)
-// CHECK-NEXT: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32):
-// CHECK-NEXT: %[[T1:.*]] = arith.cmpi sge, %[[OUT0]], %[[IN0]] : i32
-// CHECK-NEXT: %[[T2:.*]] = arith.select %[[T1]], %[[OUT0]], %[[IN0]] : i32
-// CHECK-NEXT: %[[T3:.*]] = arith.cmpi eq, %[[OUT0]], %[[IN0]] : i32
-// CHECK-NEXT: %[[T4:.*]] = arith.minsi %[[OUT1:.*]], %[[IN1]] : i32
-// CHECK-NEXT: %[[T5:.*]] = arith.select %[[T1]], %[[OUT1]], %[[IN1]] : i32
-// CHECK-NEXT: %[[T6:.*]] = arith.select %[[T3]], %[[T4]], %[[T5]] : i32
-// CHECK-NEXT: linalg.yield %[[T2]], %[[T6]]
-
-// CHECK-PRIMITIVE-DAG: %[[CST0:.*]] = arith.constant -2147483648 : i32
-// CHECK-PRIMITIVE-DAG: %[[CST1:.*]] = arith.constant 0 : i32
-// CHECK-PRIMITIVE: %[[INIT0:.*]] = tensor.empty() : tensor<2xi32>
-// CHECK-PRIMITIVE: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]]
-// CHECK-PRIMITIVE: %[[INIT1:.*]] = tensor.empty() : tensor<2xi32>
-// CHECK-PRIMITIVE: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]]
-// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.reduce
-// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>)
-// CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>)
-// CHECK-PRIMITIVE-SAME: dimensions = [0]
-// CHECK-PRIMITIVE-NEXT: (%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32) {
-// CHECK-PRIMITIVE-NEXT: %[[T1:.*]] = arith.cmpi sge, %[[OUT0]], %[[IN0]] : i32
-// CHECK-PRIMITIVE-NEXT: %[[T2:.*]] = arith.select %[[T1]], %[[OUT0]], %[[IN0]] : i32
-// CHECK-PRIMITIVE-NEXT: %[[T3:.*]] = arith.cmpi eq, %[[OUT0]], %[[IN0]] : i32
-// CHECK-PRIMITIVE-NEXT: %[[T4:.*]] = arith.minsi %[[OUT1:.*]], %[[IN1]] : i32
-// CHECK-PRIMITIVE-NEXT: %[[T5:.*]] = arith.select %[[T1]], %[[OUT1]], %[[IN1]] : i32
-// CHECK-PRIMITIVE-NEXT: %[[T6:.*]] = arith.select %[[T3]], %[[T4]], %[[T5]] : i32
-// CHECK-PRIMITIVE-NEXT: linalg.yield %[[T2]], %[[T6]]
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @variadic_diff_type_reduce
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-// CHECK-PRIMITIVE-LABEL: func @variadic_diff_type_reduce
-// CHECK-PRIMITIVE-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-PRIMITIVE-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @variadic_diff_type_reduce(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128xf32>, tensor<128xi32>) {
- %cst0 = stablehlo.constant dense<1.0> : tensor
- %cst1 = stablehlo.constant dense<1> : tensor
- %res0, %res1 = "stablehlo.reduce"(%arg0, %arg1, %cst0, %cst1) ({
- ^bb0(%arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor):
- %0 = "stablehlo.compare"(%arg7, %arg9) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor
- %1 = "stablehlo.select"(%0, %arg7, %arg9) : (tensor, tensor, tensor) -> tensor
- %2 = "stablehlo.select"(%0, %arg8, %arg10) : (tensor, tensor, tensor) -> tensor
- "stablehlo.return"(%1, %2) : (tensor, tensor) -> ()
- }) {dimensions = array} : (tensor<128x10xf32>, tensor<128x10xi32>, tensor, tensor) ->(tensor<128xf32>, tensor<128xi32>)
- func.return %res0, %res1 : tensor<128xf32>, tensor<128xi32>
-}
-// CHECK-DAG: %[[CST0:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
-// CHECK: %[[INIT0:.*]] = tensor.empty() : tensor<128xf32>
-// CHECK: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]]
-// CHECK: %[[INIT1:.*]] = tensor.empty() : tensor<128xi32>
-// CHECK: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]]
-// CHECK: %[[RES:.+]]:2 = linalg.generic {
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<128x10xf32>, tensor<128x10xi32>)
-// CHECK-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<128xf32>, tensor<128xi32>)
-// CHECK-NEXT: ^bb0(%[[LHS0:.*]]: f32, %[[LHS1:.*]]: i32, %[[RHS0:.*]]: f32, %[[RHS1:.*]]: i32):
-// CHECK-NEXT: %[[B0:.*]] = arith.cmpf oge, %[[RHS0]], %[[LHS0]] : f32
-// CHECK-NEXT: %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32
-// CHECK-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32
-// CHECK-NEXT: linalg.yield %[[RES0]], %[[RES1]] : f32, i32
-
-// CHECK-PRIMITIVE-DAG: %[[CST0:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-PRIMITIVE-DAG: %[[CST1:.*]] = arith.constant 1 : i32
-// CHECK-PRIMITIVE: %[[INIT0:.*]] = tensor.empty() : tensor<128xf32>
-// CHECK-PRIMITIVE: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]]
-// CHECK-PRIMITIVE: %[[INIT1:.*]] = tensor.empty() : tensor<128xi32>
-// CHECK-PRIMITIVE: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]]
-// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.reduce
-// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<128x10xf32>, tensor<128x10xi32>)
-// CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<128xf32>, tensor<128xi32>)
-// CHECK-PRIMITIVE-SAME: dimensions = [1]
-// CHECK-PRIMITIVE-NEXT: (%[[LHS0:.*]]: f32, %[[LHS1:.*]]: i32, %[[RHS0:.*]]: f32, %[[RHS1:.*]]: i32) {
-// CHECK-PRIMITIVE-NEXT: %[[B0:.*]] = arith.cmpf oge, %[[RHS0]], %[[LHS0]] : f32
-// CHECK-PRIMITIVE-NEXT: %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32
-// CHECK-PRIMITIVE-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32
-// CHECK-PRIMITIVE-NEXT: linalg.yield %[[RES0]], %[[RES1]] : f32, i32
-
-// -----
-
-// Make sure we do not crash on unsupported reductions.
- -// CHECK-LABEL: func.func @reduce_noop -// CHECK: stablehlo.reduce -// CHECK-PRIMITIVE-LABEL: func.func @reduce_noop -// CHECK-PRIMITIVE: stablehlo.reduce -func.func @reduce_noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - %0 = stablehlo.constant dense<0.000000e+00> : tensor - %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> - reducer(%arg1: tensor, %arg2: tensor) { - %4 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %4 : tensor - } - func.return %1 : tensor<4x8xf32> -} - -// CHECK-LABEL: func.func @reduce_zero_ext -// CHECK: stablehlo.reduce -// CHECK-PRIMITIVE-LABEL: func.func @reduce_zero_ext -// CHECK-PRIMITIVE: stablehlo.reduce -func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor { - %0 = stablehlo.constant dense : tensor - %1 = stablehlo.constant dense : tensor<0xi1> - %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> - %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32> - %4 = stablehlo.constant dense<0> : tensor - %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor) -> tensor - reducer(%arg1: tensor, %arg2: tensor) { - %6 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %6 : tensor - } - return %5 : tensor -} - -// ----- - -// CHECK-LABEL: func @reduce_window_min_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_min_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.minimum %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array, - someattr} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %0 : tensor<1x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_min -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: someattr, -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_max_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_max_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.maximum %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %0 : tensor<1x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max -// CHECK-SAME: {dilations = 
dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_sum_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_sum_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> tensor<1x8x8x64xf32>{ - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %0 : tensor<1x8x8x64xf32> -} -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32> -// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_max_nhwc_with_cst -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -func.func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x17x17x64xf32>) -> tensor<1x8x8x64xf32> { - %0 = arith.constant dense<0xFF800000> : tensor - %1 = "stablehlo.reduce_window"(%arg0, %0) ({ - ^bb0(%arg1: tensor, %arg2 : tensor): - %2 = stablehlo.maximum %arg1, %arg2 : tensor - "stablehlo.return"(%2) : (tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor) -> tensor<1x8x8x64xf32> - func.return %1 : tensor<1x8x8x64xf32> -} - -// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 -// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32 -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> -// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max -// CHECK-SAME: {dilations = dense<1> : vector<2xi64> -// CHECK-SAME: strides = dense<2> : vector<2xi64>} -// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>) -// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_sum_max_nhwc -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_sum_max_nhwc(%arg0: tensor<1x17x17x64xf32>, - %arg1: tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) { - %0:2 = "stablehlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor, %arg4: tensor, %arg5 : tensor): - %1 = stablehlo.add %arg2, %arg4 : tensor - %2 = stablehlo.maximum %arg3, %arg5 : tensor - "stablehlo.return"(%1, %2) : (tensor, tensor) -> () - }) {window_dimensions = array, - window_strides = array} : (tensor<1x17x17x64xf32>, tensor<1x17x17x64xf32>, tensor, tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) - func.return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32> -} - -// CHECK: %[[WINDOW0:.+]] = tensor.empty() : tensor<3x3xf32> -// 
-// CHECK: %[[INIT0:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
-// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
-// CHECK: %[[FILL0:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
-// CHECK: %[[RES0:.+]] = linalg.pooling_nhwc_sum
-// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
-// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
-// CHECK-SAME: outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
-// CHECK: %[[WINDOW1:.+]] = tensor.empty() : tensor<3x3xf32>
-// CHECK: %[[INIT1:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
-// CHECK: %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
-// CHECK: %[[FILL1:.+]] = linalg.fill ins(%[[INIT_VAL1]] : f32) outs(%[[INIT1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
-// CHECK: %[[RES1:.+]] = linalg.pooling_nhwc_max
-// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
-// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
-// CHECK-SAME: outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
-// CHECK: return %[[RES0]], %[[RES1]]
-
-// -----
-
-// Just check that this lowers successfully.
-// CHECK-LABEL: func @reduce_window_unsigned
-func.func @reduce_window_unsigned(%arg0: tensor<1x1xui32>) -> tensor<1x1xui32> {
-  %0 = stablehlo.constant dense<0> : tensor<ui32>
-  %1 = "stablehlo.reduce_window"(%arg0, %0) ({
-  ^bb0(%arg1: tensor<ui32>, %arg2: tensor<ui32>):
-    stablehlo.return %arg1 : tensor<ui32>
-  }) {
-    window_dimensions = array<i64: 1, 1>,
-    window_strides = array<i64: 1, 1>
-  } : (tensor<1x1xui32>, tensor<ui32>) -> tensor<1x1xui32>
-  return %1 : tensor<1x1xui32>
-}
-
-// -----
-
-// CHECK-LABEL: func @dynamic_reduce_window_sum_nhwc
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @dynamic_reduce_window_sum_nhwc(%arg0: tensor<?x?x?x?xf32>,
-                                          %arg1: tensor<f32>) -> tensor<?x?x?x?xf32>{
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {window_dimensions = array<i64: 1, 3, 3, 1>,
-      window_strides = array<i64: 1, 2, 2, 1>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
-  func.return %0 : tensor<?x?x?x?xf32>
-}
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
-// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[C3]]
-// CHECK: %[[T3:.+]] = arith.divui %[[T2]], %[[C2]]
-// CHECK: %[[D1:.+]] = arith.addi %[[T3]], %[[C1]]
-// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[C3]]
-// CHECK: %[[T3:.+]] = arith.divui %[[T2]], %[[C2]]
-// CHECK: %[[D2:.+]] = arith.addi %[[T3]], %[[C1]]
-// CHECK: %[[D3:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) : tensor<?x?x?x?xf32>
-// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum
-// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
-// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<?x?x?x?xf32>, tensor<3x3xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-
-// -----
-
-// CHECK-LABEL: func @reduce_window_min_ndhwc
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @reduce_window_min_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
-                                   %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
-    %1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {window_dimensions = array<i64: 1, 3, 3, 3, 1>,
-      window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
-  func.return %0 : tensor<1x8x8x8x64xf32>
-}
-// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
-// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
-// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_min
-// CHECK-SAME: {dilations = dense<1> : vector<3xi64>
-// CHECK-SAME: strides = dense<2> : vector<3xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
-
-// -----
-
-// CHECK-LABEL: func @reduce_window_max_ndhwc
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @reduce_window_max_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
-                                   %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
-    %1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {window_dimensions = array<i64: 1, 3, 3, 3, 1>,
-      window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
-  func.return %0 : tensor<1x8x8x8x64xf32>
-}
-// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
-// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
-// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_max
-// CHECK-SAME: {dilations = dense<1> : vector<3xi64>
-// CHECK-SAME: strides = dense<2> : vector<3xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
-
-// -----
-
-// CHECK-LABEL: func @reduce_window_sum_ndhwc
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @reduce_window_sum_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
-                                   %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {window_dimensions = array<i64: 1, 3, 3, 3, 1>,
-      window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
-  func.return %0 : tensor<1x8x8x8x64xf32>
-}
-// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
-// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
-// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_sum
-// CHECK-SAME: {dilations = dense<1> : vector<3xi64>
-// CHECK-SAME: strides = dense<2> : vector<3xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
-
-// -----
-
-// CHECK-LABEL: func @reduce_window_sum_ndhwc_dilated_base
-// CHECK: linalg.generic
-func.func @reduce_window_sum_ndhwc_dilated_base(
-    %arg0: tensor<1x17x17x17x64xf32>,
-    %arg1: tensor<f32>) -> tensor<1x8x8x16x64xf32>{
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {base_dilations = array<i64: 1, 1, 1, 2, 1>,
-      window_dimensions = array<i64: 1, 3, 3, 3, 1>,
-      window_strides = array<i64: 1, 2, 2, 2, 1>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x16x64xf32>
-  func.return %0 : tensor<1x8x8x16x64xf32>
-}
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 * 2, d1 + d2 * 2)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-// CHECK: func @reduce_window_generic
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
-func.func @reduce_window_generic(%arg0: tensor<4x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
-    "stablehlo.return"(%1) : (tensor<f32>) -> ()
-  }) {base_dilations = array<i64: 1, 1>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 2>, window_dimensions = array<i64: 1, 2>, window_strides = array<i64: 2, 1>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
-  func.return %0 : tensor<4x7xf32>
-}
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7xf32>
-// CHECK: %[[FILL:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<f32>) outs(%[[INIT]] : tensor<4x7xf32>)
-// CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
-// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
-// CHECK: linalg.yield %[[IN]] : f32
-
-// CHECK: %[[PADVAL:.+]] = tensor.extract %arg1[] : tensor<f32>
-// CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1] high[3, 2]
-// CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index
-// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index
-// CHECK: tensor.yield %[[PADVAL]] : f32
-
-// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<2xf32>
-// CHECK: %[[REDUCE:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[PAD]], %[[WINDOW]] : tensor<7x9xf32>, tensor<2xf32>) outs(%[[FILL]] : tensor<4x7xf32>) {
-// CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
-// CHECK-SAME: %[[IN2:[a-zA-Z0-9_]*]]: f32
-// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
-// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[IN]] : f32
-// CHECK: linalg.yield %[[ADD]]
-
-// CHECK: return %[[REDUCE]]
-// -----
-
-// CHECK-LABEL: func @reduce_window_generic_captured_constant
-func.func @reduce_window_generic_captured_constant(%arg0: tensor<4x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
-  %c2 = stablehlo.constant dense<2.0> : tensor<f32>
-  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
-  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
-    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
-    %2 = stablehlo.multiply %1, %c2 : tensor<f32>
: tensor - "stablehlo.return"(%2) : (tensor) -> () - }) {base_dilations = array, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<4x6xf32>, tensor) -> tensor<4x7xf32> - func.return %0 : tensor<4x7xf32> -} - -// CHECK: %[[C2:.*]] = arith.constant 2.0 -// CHECK: linalg.generic -// CHECK: %[[SUM:.*]] = arith.addf -// CHECK: %[[PROD:.*]] = arith.mulf %[[SUM]], %[[C2]] -// CHECK: linalg.yield %[[PROD]] - -// ----- - -// CHECK-LABEL: func @reduce_window_generic_padding -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic_padding(%arg0: tensor<3x6xf32>, %arg1: tensor) -> tensor<3x7xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<3x6xf32>, tensor) -> tensor<3x7xf32> - func.return %0 : tensor<3x7xf32> -} -// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[0, 1] high[3, 2] -// CHECK: tensor.yield %[[PADVAL]] : f32 - -// ----- - -// CHECK-LABEL: func @reduce_window_generic_base_dilation -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor) -> tensor<3x4xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<3x6xf32>, tensor) -> tensor<3x4xf32> - func.return %0 : tensor<3x4xf32> -} -// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<5x6xf32>) -> tensor<5x6xf32> -// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 0] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<5x6xf32> - -// ----- - -// CHECK-LABEL: func @reduce_window_generic_padding_base_dilation -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] -func.func @reduce_window_generic_padding_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor) -> tensor<4x7xf32> { - %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<3x6xf32>, tensor) -> tensor<4x7xf32> - func.return %0 : tensor<4x7xf32> -} -// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<8x9xf32> -// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<8x9xf32>) -> tensor<8x9xf32> -// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 1] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<8x9xf32> - -// ----- - -// CHECK: #[[MAP:.+]] = affine_map<() -> ()> -// CHECK: func @reduce_window_generic_scalar -func.func @reduce_window_generic_scalar(%arg0: tensor, %arg1: tensor) -> tensor { - %0 
= "stablehlo.reduce_window"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = stablehlo.add %arg2, %arg3 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {base_dilations = array, padding = dense<> : tensor<0x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]