From f2440c01fece5a5f9fff54cfca6ebb8f462a0e27 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 2 Dec 2022 18:01:40 -0600 Subject: [PATCH] StableHLO Versioned Dialect and Compatibility Passes (#278) Add StableHLO compatibility dialect and passes for reading and writing with forward/backward compatibility guarantees. **Note: This is still a prototype implementation and should not be used in production until RFCs have been approved and types have been forked.** ```bash stablehlo-opt [flags] file.mlir --stablehlo-legalize-to-vhlo --vhlo-to-version='target=[version]' Translate versioned dialect to target version (current, 0.3.0, ...) --vhlo-legalize-to-stablehlo ``` Change description: - Introduce VHLO, the Versioned StableHLO Dialect. + This dialect is a shallow copy of StableHLO's in-memory layout. It does not include verifiers or constraints. + Once an op is added to VHLO it must remain unchanged so that it can be guaranteed that a VHLO op is identical across versions. + The first version of VHLO is `0.3.0`. - Conversion passes for compatibility + StableHLO <--> VHLO legalizations. StableHLO is always able to be legalized to/from the latest version of VHLO. + VHLO-to-version. Target previous versions of VHLO ops for forward compatibility. Upgrade to the latest version of VHLO ops to emit StableHLO. - Testing for legalizations and version conversions. Future work (these items will be made into individual GH issues before submit): - Think more about a scalable way to test this as StableHLO evolves. - Additional feature work on the tool.. any missing flags? Pass pipeline for simplicity? - Improve user experience. 
- See open [compatibility issues](https://github.com/openxla/stablehlo/labels/Compatibility) Closes #255 --- stablehlo/CMakeLists.txt | 1 + stablehlo/dialect/Base.td | 17 +- stablehlo/dialect/CMakeLists.txt | 33 +- stablehlo/dialect/Register.cpp | 4 +- stablehlo/dialect/StablehloAttrs.td | 2 +- stablehlo/dialect/StablehloOps.td | 2 +- stablehlo/dialect/Version.cpp | 68 + stablehlo/dialect/Version.h | 66 + stablehlo/dialect/VhloAttrs.td | 172 ++ stablehlo/dialect/VhloBase.td | 184 ++ stablehlo/dialect/VhloEnums.td | 201 ++ stablehlo/dialect/VhloOps.cpp | 93 + stablehlo/dialect/VhloOps.h | 78 + stablehlo/dialect/VhloOps.td | 958 +++++++++ .../tests/legalize_stablehlo_to_vhlo.mlir | 1770 +++++++++++++++++ stablehlo/tests/ops_chlo_roundtrip.mlir | 1 - .../tests/vhlo_to_version_downgrade.mlir | 41 + .../vhlo_to_version_downgrade_invalid.mlir | 58 + .../vhlo_to_version_invalid_target_empty.mlir | 2 + ...vhlo_to_version_invalid_target_future.mlir | 2 + ...hlo_to_version_invalid_target_minimum.mlir | 2 + ...to_version_invalid_target_not_version.mlir | 2 + stablehlo/tests/vhlo_to_version_upgrade.mlir | 23 + stablehlo/tools/CMakeLists.txt | 1 + stablehlo/tools/StablehloOptMain.cpp | 2 + stablehlo/transforms/CMakeLists.txt | 43 + .../transforms/LegalizeStablehloToVhlo.cpp | 227 +++ .../transforms/LegalizeVhloToStablehlo.cpp | 216 ++ stablehlo/transforms/MapStablehloToVhlo.h | 171 ++ stablehlo/transforms/Passes.h | 51 + stablehlo/transforms/Passes.td | 34 + stablehlo/transforms/TypeConversion.cpp | 42 + stablehlo/transforms/TypeConversion.h | 136 ++ stablehlo/transforms/VhloToVersion.cpp | 261 +++ 34 files changed, 4949 insertions(+), 15 deletions(-) create mode 100644 stablehlo/dialect/Version.cpp create mode 100644 stablehlo/dialect/Version.h create mode 100644 stablehlo/dialect/VhloAttrs.td create mode 100644 stablehlo/dialect/VhloBase.td create mode 100644 stablehlo/dialect/VhloEnums.td create mode 100644 stablehlo/dialect/VhloOps.cpp create mode 100644 
stablehlo/dialect/VhloOps.h create mode 100644 stablehlo/dialect/VhloOps.td create mode 100644 stablehlo/tests/legalize_stablehlo_to_vhlo.mlir create mode 100644 stablehlo/tests/vhlo_to_version_downgrade.mlir create mode 100644 stablehlo/tests/vhlo_to_version_downgrade_invalid.mlir create mode 100644 stablehlo/tests/vhlo_to_version_invalid_target_empty.mlir create mode 100644 stablehlo/tests/vhlo_to_version_invalid_target_future.mlir create mode 100644 stablehlo/tests/vhlo_to_version_invalid_target_minimum.mlir create mode 100644 stablehlo/tests/vhlo_to_version_invalid_target_not_version.mlir create mode 100644 stablehlo/tests/vhlo_to_version_upgrade.mlir create mode 100644 stablehlo/transforms/CMakeLists.txt create mode 100644 stablehlo/transforms/LegalizeStablehloToVhlo.cpp create mode 100644 stablehlo/transforms/LegalizeVhloToStablehlo.cpp create mode 100644 stablehlo/transforms/MapStablehloToVhlo.h create mode 100644 stablehlo/transforms/Passes.h create mode 100644 stablehlo/transforms/Passes.td create mode 100644 stablehlo/transforms/TypeConversion.cpp create mode 100644 stablehlo/transforms/TypeConversion.h create mode 100644 stablehlo/transforms/VhloToVersion.cpp diff --git a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt index 2370ddc11a..8d64676a89 100644 --- a/stablehlo/CMakeLists.txt +++ b/stablehlo/CMakeLists.txt @@ -17,3 +17,4 @@ add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) add_subdirectory(tools) +add_subdirectory(transforms) diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 32d491a302..e0888499ed 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -17,7 +17,6 @@ limitations under the License. 
#ifndef STABLEHLO_DIALECT_BASE #define STABLEHLO_DIALECT_BASE -include "mlir/Dialect/Quant/QuantOpsBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" @@ -44,7 +43,7 @@ def HLO_Complex : Complex>; //===----------------------------------------------------------------------===// // TODO(b/230381284): Upstream width-specific uniform quantized element types. -class UniformQuantizedSignedInt +class StableHLO_UniformQuantizedSignedInt : Type()">, CPred<"$_self.cast()" # @@ -61,7 +60,7 @@ class UniformQuantizedSignedInt int bitwidth = width; } -class UniformQuantizedUnsignedInt +class StableHLO_UniformQuantizedUnsignedInt : Type()">, CPred<"$_self.cast()" # @@ -78,20 +77,20 @@ class UniformQuantizedUnsignedInt int bitwidth = width; } -class UniformQuantizedSignedIntOfWidths widths> : - AnyTypeOf), +class StableHLO_UniformQuantizedSignedIntOfWidths widths> : + AnyTypeOf), !interleave(widths, "/") # "-bit uniform quantized signed " # "integer">; -class UniformQuantizedUnsignedIntOfWidths widths> : - AnyTypeOf), +class StableHLO_UniformQuantizedUnsignedIntOfWidths widths> : + AnyTypeOf), !interleave(widths, "/") # "-bit uniform quantized unsigned " # "integer">; // Integer-based uniform quantized types. The definitions can be used to specify // operand's tensor types. 
-def HLO_QuantizedSignedInt : UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>; -def HLO_QuantizedUnsignedInt : UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>; +def HLO_QuantizedSignedInt : StableHLO_UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>; +def HLO_QuantizedUnsignedInt : StableHLO_UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>; def HLO_QuantizedInt : AnyTypeOf<[HLO_QuantizedSignedInt, HLO_QuantizedUnsignedInt]>; diff --git a/stablehlo/dialect/CMakeLists.txt b/stablehlo/dialect/CMakeLists.txt index be241a78b2..8eb273a1a0 100644 --- a/stablehlo/dialect/CMakeLists.txt +++ b/stablehlo/dialect/CMakeLists.txt @@ -84,15 +84,17 @@ add_mlir_dialect_library(StablehloRegister DEPENDS ChloOpsIncGen StablehloOpsIncGen + VhloOpsIncGen LINK_LIBS PUBLIC ChloOps StablehloOps + VhloOps ) add_mlir_dialect_library(StablehloAssemblyFormat PARTIAL_SOURCES_INTENDED - AssemblyFormat.cpp + AssemblyFormat.cpp LINK_LIBS PUBLIC StablehloBase @@ -101,7 +103,7 @@ add_mlir_dialect_library(StablehloAssemblyFormat add_mlir_dialect_library(StablehloTypeInference PARTIAL_SOURCES_INTENDED - TypeInference.cpp + TypeInference.cpp LINK_LIBS PUBLIC MLIRInferTypeOpInterface @@ -147,3 +149,30 @@ target_include_directories(StablehloOps INTERFACE $ $ ) + +set(LLVM_TARGET_DEFINITIONS VhloOps.td) +mlir_tablegen(VhloOps.h.inc -gen-op-decls) +mlir_tablegen(VhloOps.cpp.inc -gen-op-defs) +mlir_tablegen(VhloOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(VhloOpInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(VhloEnums.h.inc -gen-enum-decls) +mlir_tablegen(VhloEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls) +mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(VhloOpsIncGen) +add_dependencies(mlir-headers VhloOpsIncGen) + +add_mlir_dialect_library(VhloOps + PARTIAL_SOURCES_INTENDED + VhloOps.cpp + Version.cpp + + DEPENDS + VhloOpsIncGen + + LINK_LIBS PUBLIC + StablehloAssemblyFormat + MLIRIR + MLIRSupport + 
MLIRQuantDialect +) diff --git a/stablehlo/dialect/Register.cpp b/stablehlo/dialect/Register.cpp index 31b94cbb78..99221f02f9 100644 --- a/stablehlo/dialect/Register.cpp +++ b/stablehlo/dialect/Register.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/VhloOps.h" namespace mlir { namespace stablehlo { @@ -27,7 +28,8 @@ void registerAllDialects(mlir::DialectRegistry ®istry) { // clang-format off registry.insert(); registry.insert(); + mlir::stablehlo::StablehloDialect, + mlir::vhlo::VhloDialect>(); // clang-format on } diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td index b069f3c744..99aa6d205b 100644 --- a/stablehlo/dialect/StablehloAttrs.td +++ b/stablehlo/dialect/StablehloAttrs.td @@ -64,7 +64,7 @@ def StableHLO_DotDimensionNumbers : AttrDef { +def StableHLO_OutputOperandAlias : AttrDef { let cppNamespace = "::mlir::stablehlo"; let mnemonic = "output_operand_alias"; let summary = diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index bb13abd9f8..a2969dd3bb 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2101,7 +2101,7 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call", OptionalAttr:$result_layouts, DefaultValuedOptionalAttr< TypedArrayAttrBase< - OutputOperandAlias, + StableHLO_OutputOperandAlias, "Aliasing attribute for outputs and operands of CustomCall">, "{}">:$output_operand_aliases ); diff --git a/stablehlo/dialect/Version.cpp b/stablehlo/dialect/Version.cpp new file mode 100644 index 0000000000..bbf4e943c3 --- /dev/null +++ b/stablehlo/dialect/Version.cpp @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2022 The StableHLO Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/dialect/Version.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Diagnostics.h" + +namespace mlir { +namespace vhlo { +namespace { +// Helper function to parse a number from a string. +// Precondition: numRef is a valid decimal number. +static int64_t parseNumber(llvm::StringRef numRef) { + int64_t num; + if (numRef.getAsInteger(/*radix=*/10, num)) { + llvm_unreachable("failed to parse version number"); + } + return num; +} + +/// Validate version argument is `#.#.#` (ex: 0.3.0, 1.2.3, 0.123.0) +/// Returns the vector of 3 matches (major, minor, patch) if successful, +/// else returns failure.
+static FailureOr<std::array<int64_t, 3>> extractVersionNumbers( + llvm::StringRef versionRef) { + llvm::Regex versionRegex("^([0-9]+)\\.([0-9]+)\\.([0-9]+)$"); + llvm::SmallVector<llvm::StringRef> matches; + if (!versionRegex.match(versionRef, &matches)) { + return failure(); + } + return std::array{parseNumber(matches[1]), + parseNumber(matches[2]), + parseNumber(matches[3])}; +} +} // namespace + +FailureOr<Version> Version::fromString(llvm::StringRef versionRef) { + auto failOrVersionArray = extractVersionNumbers(versionRef); + if (failed(failOrVersionArray)) { + return failure(); + } + + auto versionArr = *failOrVersionArray; + return Version(versionArr[0], versionArr[1], versionArr[2]); +} + +mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version) { + return diag << version.getMajor() << '.' << version.getMinor() << '.' + << version.getPatch(); +} + +} // namespace vhlo +} // namespace mlir diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h new file mode 100644 index 0000000000..9a4a837044 --- /dev/null +++ b/stablehlo/dialect/Version.h @@ -0,0 +1,66 @@ +/* Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef STABLEHLO_DIALECT_VERSION_H +#define STABLEHLO_DIALECT_VERSION_H + +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Regex.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace vhlo { + +class Version { + public: + /// Convenience method to extract major, minor, patch and create a Version + /// from a StringRef of the form `#.#.#`. Returns failure if invalid string. + static FailureOr fromString(llvm::StringRef versionRef); + + /// Construct Version from major, minor, patch integers. + Version(int64_t major, int64_t minor, int64_t patch) + : majorMinorPatch({major, minor, patch}) {} + + int64_t getMajor() const { return majorMinorPatch[0]; } + int64_t getMinor() const { return majorMinorPatch[1]; } + int64_t getPatch() const { return majorMinorPatch[2]; } + + bool operator<(Version const& other) { + // Uses lexicographical_compare + return majorMinorPatch < other.majorMinorPatch; + } + bool operator==(Version const& other) { + return majorMinorPatch == other.majorMinorPatch; + } + bool operator<=(Version const& other) { + return majorMinorPatch <= other.majorMinorPatch; + } + + private: + std::array majorMinorPatch; +}; + +mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version); + +} // namespace vhlo +} // namespace mlir + +#endif // STABLEHLO_DIALECT_VERSION_H diff --git a/stablehlo/dialect/VhloAttrs.td b/stablehlo/dialect/VhloAttrs.td new file mode 100644 index 0000000000..a408e7f47c --- /dev/null +++ b/stablehlo/dialect/VhloAttrs.td @@ -0,0 +1,172 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_DIALECT_VHLO_ATTRS +#define STABLEHLO_DIALECT_VHLO_ATTRS + +include "stablehlo/dialect/VhloBase.td" +include "stablehlo/dialect/VhloEnums.td" + +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// Attributes +//===----------------------------------------------------------------------===// + +def VHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> { + let parser = "mlir::hlo::parseDimSizes($_parser)"; + let printer = "mlir::hlo::printDimSizes($_printer, $_self)"; +} + +def VHLO_ScatterDimensionNumbers : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "scatter"; + let parameters = (ins + VHLO_Dims:$updateWindowDims, + VHLO_Dims:$insertedWindowDims, + VHLO_Dims:$scatterDimsToOperandDims, + "int64_t":$indexVectorDim + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_GatherDimensionNumbers : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "gather"; + let parameters = (ins + VHLO_Dims:$offsetDims, + VHLO_Dims:$collapsedSliceDims, + VHLO_Dims:$startIndexMap, + "int64_t":$indexVectorDim + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_DotDimensionNumbers : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "dot"; + let parameters = (ins + VHLO_Dims:$lhsBatchingDimensions, + VHLO_Dims:$rhsBatchingDimensions, + VHLO_Dims:$lhsContractingDimensions, + VHLO_Dims:$rhsContractingDimensions + ); + let assemblyFormat = "`<` 
struct(params) `>`"; +} + +def VHLO_OutputOperandAlias : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "output_operand_alias"; + let parameters = (ins + VHLO_Dims:$outputTupleIndices, + "int64_t":$operandIndex, + VHLO_Dims:$operandTupleIndices + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_ArgResultAlias : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "result_alias"; + let parameters = (ins + VHLO_Dims:$argTupleIndices, + "int64_t":$resultIndex, + VHLO_Dims:$resultTupleIndices, + "bool":$isMustAlias + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_ChannelHandle : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "channel_handle"; + let parameters = (ins "int64_t":$handle, "int64_t":$type); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_TypeExtensions : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "type_extensions"; + let parameters = (ins VHLO_Dims:$bounds); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_LayoutAttr : Attr< + And<[IndexElementsAttr.predicate, + CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank() + == 1}]>]>, + "A 1D tensor of index type (layout)"> { + let storageType = IndexElementsAttr.storageType; + let returnType = IndexElementsAttr.returnType; + let convertFromStorage = IndexElementsAttr.convertFromStorage; +} + +// An array of layout (1D tensor) attributes. +def VHLO_ArrayOfLayoutAttr : TypedArrayAttrBase; + +// An array of FlatSymbolRef attributes that can be used as a default valued +// attribute. 
+def VHLO_FlatSymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)"; +} + +def VHLO_BoolElementsAttr : + ElementsAttrBase< + And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">, + CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>, + "constant boolean vector/tensor attribute"> { + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +def VHLO_ConvDimensionNumbers : AttrDef { + let cppNamespace = "::mlir::vhlo"; + let mnemonic = "conv"; + let parameters = (ins + "int64_t":$inputBatchDimension, + "int64_t":$inputFeatureDimension, + VHLO_Dims:$inputSpatialDimensions, + + "int64_t":$kernelInputFeatureDimension, + "int64_t":$kernelOutputFeatureDimension, + VHLO_Dims:$kernelSpatialDimensions, + + "int64_t":$outputBatchDimension, + "int64_t":$outputFeatureDimension, + VHLO_Dims:$outputSpatialDimensions + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def VHLO_ConvolutionAttributes { + dag attributes = (ins + OptionalAttr:$window_strides, + OptionalAttr:$padding, + OptionalAttr:$lhs_dilation, + OptionalAttr:$rhs_dilation, + OptionalAttr:$window_reversal, + VHLO_ConvDimensionNumbers:$dimension_numbers, + I64Attr:$feature_group_count, + I64Attr:$batch_group_count, + VHLO_PrecisionConfigAttr:$precision_config + ); +} + +#endif // STABLEHLO_DIALECT_VHLO_ATTRS diff --git a/stablehlo/dialect/VhloBase.td b/stablehlo/dialect/VhloBase.td new file mode 100644 index 0000000000..9537f826c8 --- /dev/null +++ b/stablehlo/dialect/VhloBase.td @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_DIALECT_VHLO_BASE +#define STABLEHLO_DIALECT_VHLO_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// VHLO Versioning Interfaces +//===----------------------------------------------------------------------===// + +def VersionedInterface : OpInterface<"VersionedInterface"> { + let methods = [ + InterfaceMethod< + "Returns the minimum version an op is supported in.", + "mlir::vhlo::Version", "getMinVersion">, + InterfaceMethod< + "Returns the maximum version an op is supported in.", + "mlir::vhlo::Version", "getMaxVersion">, + ]; +} + +//===----------------------------------------------------------------------===// +// VHLO Type Definitions. +//===----------------------------------------------------------------------===// + +def VHLO_Pred : TypeAlias<I1, "pred (AKA boolean)">; + +// TODO(hinsu): Use signed integers instead of signless integer which is being +// used for legacy reasons. +def VHLO_SInt : SignlessIntOfWidths<[4, 8, 16, 32, 64]>; +def VHLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>; +def VHLO_Int : AnyTypeOf<[VHLO_SInt, VHLO_UInt]>; + +def VHLO_Float : AnyTypeOf<[F16, F32, F64, BF16]>; +def VHLO_Float32Or64 : AnyTypeOf<[F32, F64]>; + +def VHLO_Complex : Complex<AnyTypeOf<[F32, F64]>>; + +//===----------------------------------------------------------------------===// +// Quantized element type definitions.
+//===----------------------------------------------------------------------===// + +// TODO(b/230381284): Upstream width-specific uniform quantized element types. +class VHLO_UniformQuantizedSignedInt + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # width>, + CPred<"$_self.cast()" # + ".isSigned()">]>, + And<[CPred<"$_self.isa()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # width>, + CPred<"$_self.cast()" # + ".isSigned()">]>]>, + "QI" # width # " type"> { + string name = "UniformQuantizedSignedInt"; + int bitwidth = width; +} + +class VHLO_UniformQuantizedUnsignedInt + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # width>, + CPred<"!$_self.cast()" # + ".isSigned()">]>, + And<[CPred<"$_self.isa()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # width>, + CPred<"!$_self.cast()" # + ".isSigned()">]>]>, + "QUI" # width # " type"> { + string name = "UniformQuantizedUnsignedInt"; + int bitwidth = width; +} + +class VHLO_UniformQuantizedSignedIntOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit uniform quantized signed " # + "integer">; + +class VHLO_UniformQuantizedUnsignedIntOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit uniform quantized unsigned " # + "integer">; + +// Integer-based uniform quantized types. The definitions can be used to specify +// operand's tensor types. +def VHLO_QuantizedSignedInt : VHLO_UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>; +def VHLO_QuantizedUnsignedInt : VHLO_UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>; +def VHLO_QuantizedInt : + AnyTypeOf<[VHLO_QuantizedSignedInt, VHLO_QuantizedUnsignedInt]>; + +// The broadcasting dimensions correspond to a tuple that describes how a +// smaller rank shape is broadcast into a larger rank shape. For example, +// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means +// matching the matrix to dimensions 1 and 2 of the cuboid. 
+defvar BroadcastDimAttr = I64ElementsAttr; + +// Token type. +def VHLO_Token : Type()">, "token">; + +// Any integer tensor types +def VHLO_IntTensor : TensorOf<[VHLO_Int]>; + +// Any integer tensor type with rank 0 (i.e. representing a single integer). +def VHLO_ScalarIntTensor : 0DTensorOf<[VHLO_Int]>; + +// Any floating-point tensor types +def VHLO_FpTensor : TensorOf<[VHLO_Float]>; + +// 32 or 64 bits floating-point tensor types +def VHLO_Fp32Or64Tensor : TensorOf<[VHLO_Float32Or64]>; + +// Any quantized integer tensor types +def VHLO_QuantizedIntTensor : TensorOf<[VHLO_QuantizedInt]>; + +def VHLO_PredTensor : TensorOf<[VHLO_Pred]>; + +def VHLO_Tensor : TensorOf<[VHLO_Float, VHLO_Pred, VHLO_Int, VHLO_Complex, VHLO_QuantizedInt]>; + +def VHLO_ComplexTensor : TensorOf<[VHLO_Complex]>; + +def VHLO_Tuple : NestedTupleOf<[VHLO_Tensor, VHLO_Token]>; + +def VHLO_TensorOrToken : AnyTypeOf<[VHLO_Tensor, VHLO_Token]>; + +def VHLO_TensorOrTokenOrTuple : AnyTypeOf<[VHLO_Tensor, VHLO_Token, VHLO_Tuple]>; + +def VHLO_DimensionValue : AnyTypeOf<[Index, VHLO_Int]>; + +// Dynamic representation of a shape vector as a tensor. +def VHLO_DimensionTensor : 1DTensorOf<[VHLO_DimensionValue]>; + +// In general, static shaped tensor constraints should be avoided unless +// it is for a legacy op which is only correct with static shapes. +def VHLO_StaticShapeTensor : StaticShapeTensorOf<[ + VHLO_Float, VHLO_Pred, VHLO_Int, VHLO_Complex, VHLO_QuantizedInt]>; + +//===----------------------------------------------------------------------===// +// VHLO combined type definitions. 
+//===----------------------------------------------------------------------===// + +// Any integer or floating-point tensor types +def VHLO_IntOrFpTensor : TensorOf<[VHLO_Int, VHLO_Float]>; + +// Any integer or predicate tensor types +def VHLO_PredOrIntTensor : TensorOf<[VHLO_Pred, VHLO_Int]>; + +// Any floating-point or complex tensor types +def VHLO_FpOrComplexTensor : TensorOf<[VHLO_Float, VHLO_Complex]>; + +// Any int, floating-point or complex tensor types +def VHLO_IntFpOrComplexTensor : TensorOf<[VHLO_Int, VHLO_Float, VHLO_Complex]>; + +// Any pred, int or floating-point tensor types +def VHLO_PredIntOrFpTensor : TensorOf<[VHLO_Pred, VHLO_Int, VHLO_Float]>; + +//===----------------------------------------------------------------------===// +// VHLO traits +//===----------------------------------------------------------------------===// + +// VHLO intentionally does not include traits since it is only used to represent +// the in-memory structure of the IR at a given version. +// +// This section is included for file parity between VhloBase.td and Base.td + +#endif // STABLEHLO_DIALECT_VHLO_BASE diff --git a/stablehlo/dialect/VhloEnums.td b/stablehlo/dialect/VhloEnums.td new file mode 100644 index 0000000000..0a1f0e8431 --- /dev/null +++ b/stablehlo/dialect/VhloEnums.td @@ -0,0 +1,201 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef STABLEHLO_DIALECT_VHLO_ENUMS +#define STABLEHLO_DIALECT_VHLO_ENUMS + +include "stablehlo/dialect/VhloBase.td" + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/PatternBase.td" + +//===----------------------------------------------------------------------===// +// Enumerations +//===----------------------------------------------------------------------===// + +// These mirror the XLA PrecisionConfig proto enum. +def VHLO_PRECISION_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; +def VHLO_PRECISION_HIGH : I32EnumAttrCase<"HIGH", 1>; +def VHLO_PRECISION_HIGHEST : I32EnumAttrCase<"HIGHEST", 2>; + +def VHLO_Precision : I32EnumAttr<"Precision", + "XLA precision for an operand. Has backend specific meaning.", + [ + VHLO_PRECISION_DEFAULT, + VHLO_PRECISION_HIGH, + VHLO_PRECISION_HIGHEST + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vhlo"; +} + +def VHLO_PrecisionAttr : EnumAttr; + +// TODO(b/129153247) See if it's possible to also validate the size. +def VHLO_PrecisionConfigAttr: + OptionalAttr< + TypedArrayAttrBase>; + +//===----------------------------------------------------------------------===// +// Fast Fourier Transform Type enum definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA FftType proto enum. 
+def VHLO_FFT_TYPE_FFT : I32EnumAttrCase<"FFT", 0>; +def VHLO_FFT_TYPE_IFFT : I32EnumAttrCase<"IFFT", 1>; +def VHLO_FFT_TYPE_RFFT : I32EnumAttrCase<"RFFT", 2>; +def VHLO_FFT_TYPE_IRFFT : I32EnumAttrCase<"IRFFT", 3>; + +def VHLO_FftType : I32EnumAttr<"FftType", + "XLA fast fourier transform type.", + [ + VHLO_FFT_TYPE_FFT, + VHLO_FFT_TYPE_IFFT, + VHLO_FFT_TYPE_RFFT, + VHLO_FFT_TYPE_IRFFT + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vhlo"; +} + +def VHLO_FftTypeAttr : EnumAttr; + +//===----------------------------------------------------------------------===// +// Custom call enum definitions. +//===----------------------------------------------------------------------===// + +def VHLO_CUSTOM_CALL_API_VERISON_UNSPECIFIED : + I32EnumAttrCase<"API_VERSION_UNSPECIFIED", 0>; +def VHLO_CUSTOM_CALL_API_VERSION_ORIGINAL : + I32EnumAttrCase<"API_VERSION_ORIGINAL", 1>; +def VHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING : + I32EnumAttrCase<"API_VERSION_STATUS_RETURNING", 2>; +def VHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED : + I32EnumAttrCase<"API_VERSION_STATUS_RETURNING_UNIFIED", 3>; +def VHLO_CustomCallApiVersionAttr : + I32EnumAttr<"CustomCallApiVersion", "Custom call API version", [ + VHLO_CUSTOM_CALL_API_VERISON_UNSPECIFIED, + VHLO_CUSTOM_CALL_API_VERSION_ORIGINAL, + VHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING, + VHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED + ]> { + let cppNamespace = "::mlir::vhlo"; +} + +//===----------------------------------------------------------------------===// +// Comparison op definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA ComparisonDirection enum. 
+def VHLO_COMPARISON_DIRECTION_EQ : I32EnumAttrCase<"EQ", 0>; +def VHLO_COMPARISON_DIRECTION_NE : I32EnumAttrCase<"NE", 1>; +def VHLO_COMPARISON_DIRECTION_GE : I32EnumAttrCase<"GE", 2>; +def VHLO_COMPARISON_DIRECTION_GT : I32EnumAttrCase<"GT", 3>; +def VHLO_COMPARISON_DIRECTION_LE : I32EnumAttrCase<"LE", 4>; +def VHLO_COMPARISON_DIRECTION_LT : I32EnumAttrCase<"LT", 5>; + +def VHLO_ComparisonDirection : I32EnumAttr<"ComparisonDirection", + "Which comparison operation to perform.", + [ + VHLO_COMPARISON_DIRECTION_EQ, + VHLO_COMPARISON_DIRECTION_NE, + VHLO_COMPARISON_DIRECTION_GE, + VHLO_COMPARISON_DIRECTION_GT, + VHLO_COMPARISON_DIRECTION_LE, + VHLO_COMPARISON_DIRECTION_LT + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vhlo"; +} + +def VHLO_ComparisonDirectionAttr : EnumAttr; + +def VHLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"::mlir::vhlo::ComparisonTypeAttr()">; +def VHLO_COMPARISON_TYPE_NOTYPE : I32EnumAttrCase<"NOTYPE", 0>; +def VHLO_COMPARISON_TYPE_FLOAT : I32EnumAttrCase<"FLOAT", 1>; +def VHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : I32EnumAttrCase<"TOTALORDER", 2>; +def VHLO_COMPARISON_TYPE_SIGNED : I32EnumAttrCase<"SIGNED", 3>; +def VHLO_COMPARISON_TYPE_UNSIGNED : I32EnumAttrCase<"UNSIGNED", 4>; + +def VHLO_ComparisonType : I32EnumAttr<"ComparisonType", + "Which comparison type to use.", + [ + VHLO_COMPARISON_TYPE_NOTYPE, + VHLO_COMPARISON_TYPE_FLOAT, + VHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, + VHLO_COMPARISON_TYPE_SIGNED, + VHLO_COMPARISON_TYPE_UNSIGNED + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vhlo"; +} + +def VHLO_ComparisonTypeAttr + : EnumAttr; + +// These mirror the XLA Transpose enum in Triangular Solve options. 
+def VHLO_TRANSPOSE_INVALID : I32EnumAttrCase<"TRANSPOSE_INVALID", 0>;
+def VHLO_NO_TRANSPOSE : I32EnumAttrCase<"NO_TRANSPOSE", 1>;
+def VHLO_TRANSPOSE : I32EnumAttrCase<"TRANSPOSE", 2>;
+def VHLO_ADJOINT : I32EnumAttrCase<"ADJOINT", 3>;
+
+def VHLO_Transpose : I32EnumAttr<"Transpose",
+    "Transpose options",
+    [
+      VHLO_TRANSPOSE_INVALID,
+      VHLO_NO_TRANSPOSE,
+      VHLO_TRANSPOSE,
+      VHLO_ADJOINT
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::vhlo";
+}
+
+def VHLO_TransposeAttr : EnumAttr<VHLO_Dialect, VHLO_Transpose, "transpose">;
+
+def VHLO_RNG_DISTRIBUTION_UNIFORM : I32EnumAttrCase<"UNIFORM", 1>;
+def VHLO_RNG_DISTRIBUTION_NORMAL : I32EnumAttrCase<"NORMAL", 2>;
+
+def VHLO_RngDistribution : I32EnumAttr<"RngDistribution",
+    "XLA PRNG distribution to be used.",
+    [
+      VHLO_RNG_DISTRIBUTION_UNIFORM,
+      VHLO_RNG_DISTRIBUTION_NORMAL
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::vhlo";
+}
+
+def VHLO_RngDistributionAttr
+    : EnumAttr<VHLO_Dialect, VHLO_RngDistribution, "rng_distribution">;
+
+def VHLO_RNG_ALGORITHM_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
+def VHLO_RNG_ALGORITHM_THREE_FRY : I32EnumAttrCase<"THREE_FRY", 1>;
+def VHLO_RNG_ALGORITHM_PHILOX : I32EnumAttrCase<"PHILOX", 2>;
+
+def VHLO_RngAlgorithm : I32EnumAttr<"RngAlgorithm",
+    "XLA PRNG algorithm to be used.",
+    [
+      VHLO_RNG_ALGORITHM_DEFAULT,
+      VHLO_RNG_ALGORITHM_THREE_FRY,
+      VHLO_RNG_ALGORITHM_PHILOX
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::vhlo";
+}
+
+def VHLO_RngAlgorithmAttr
+    : EnumAttr<VHLO_Dialect, VHLO_RngAlgorithm, "rng_algorithm">;
+
+#endif // STABLEHLO_DIALECT_VHLO_ENUMS
diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp
new file mode 100644
index 0000000000..6c12d27c58
--- /dev/null
+++ b/stablehlo/dialect/VhloOps.cpp
@@ -0,0 +1,93 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+   Copyright 2022 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "stablehlo/dialect/VhloOps.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "stablehlo/dialect/AssemblyFormat.h"
+
+// Include order matters
+#include "stablehlo/dialect/VhloEnums.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "stablehlo/dialect/VhloAttrs.cpp.inc"
+#include "stablehlo/dialect/VhloOpInterfaces.cpp.inc"
+#define GET_OP_CLASSES
+#include "stablehlo/dialect/VhloOps.cpp.inc"
+
+namespace mlir {
+namespace vhlo {
+
+using mlir::hlo::parseDimSizes;
+using mlir::hlo::printDimSizes;
+
+//===----------------------------------------------------------------------===//
+// VHLO Dialect Constructor
+//===----------------------------------------------------------------------===//
+
+VhloDialect::VhloDialect(MLIRContext* context)
+    : Dialect(getDialectNamespace(), context, TypeID::get<VhloDialect>()) {
+  addOperations<
+#define GET_OP_LIST
+#include "stablehlo/dialect/VhloOps.cpp.inc"
+      >();
+  // TODO (gleasonk): addBytecodeInterface(this);
+  addTypes<TokenType>();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "stablehlo/dialect/VhloAttrs.cpp.inc"
+      >();
+}
+
+Type VhloDialect::parseType(DialectAsmParser& parser) const {
+  StringRef dataType;
+  if (parser.parseKeyword(&dataType)) return Type();
+
+  if (dataType == "token") return TokenType::get(getContext());
+  parser.emitError(parser.getNameLoc()) << "unknown vhlo type: " << dataType;
+  return nullptr;
+}
+
+void VhloDialect::printType(Type type, DialectAsmPrinter& os) const {
+  if (type.isa<TokenType>()) {
+    os 
<< "token"; + return; + } + os << ""; +} + +// Entry point for Attribute parsing, TableGen generated code will handle the +// dispatch to the individual classes. +Attribute VhloDialect::parseAttribute(DialectAsmParser& parser, + Type type) const { + StringRef attrTag; + Attribute attr; + auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr); + if (parseResult.has_value()) return attr; + parser.emitError(parser.getNameLoc(), "unknown vhlo attribute"); + return Attribute(); +} + +// Entry point for Attribute printing, TableGen generated code will handle the +// dispatch to the individual classes. +void VhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const { + LogicalResult result = generatedAttributePrinter(attr, os); + (void)result; + assert(succeeded(result)); +} + +} // namespace vhlo +} // namespace mlir diff --git a/stablehlo/dialect/VhloOps.h b/stablehlo/dialect/VhloOps.h new file mode 100644 index 0000000000..9fda9e7e3f --- /dev/null +++ b/stablehlo/dialect/VhloOps.h @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef STABLEHLO_DIALECT_VHLO_OPS_H +#define STABLEHLO_DIALECT_VHLO_OPS_H + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "stablehlo/dialect/Version.h" + +// Include order matters. +#include "stablehlo/dialect/VhloEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "stablehlo/dialect/VhloAttrs.h.inc" + +namespace mlir { +namespace vhlo { + +class VhloDialect : public Dialect { + public: + explicit VhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "vhlo"; } + + // Parses a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + // Prints a type registered to this dialect. + void printType(Type type, DialectAsmPrinter &os) const override; + + // Parses an attribute registered to this dialect. + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; + + // Prints an attribute registered to this dialect. + void printAttribute(Attribute attr, DialectAsmPrinter &os) const override; + + /// Return a Version representing the current dialect version. + static Version getCurrentVersion() { return Version(0, 4, 0); } + + /// Return a Version representing the minimum supported dialect version. 
+  static Version getMinimumVersion() { return Version(0, 3, 0); }
+};
+
+class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
+ public:
+  using Base::Base;
+};
+
+} // namespace vhlo
+} // end namespace mlir
+
+#include "stablehlo/dialect/VhloOpInterfaces.h.inc"
+#define GET_OP_CLASSES
+#include "stablehlo/dialect/VhloOps.h.inc"
+
+#endif // STABLEHLO_DIALECT_VHLO_OPS_H
diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td
new file mode 100644
index 0000000000..d29253e5d0
--- /dev/null
+++ b/stablehlo/dialect/VhloOps.td
@@ -0,0 +1,958 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+   Copyright 2022 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef STABLEHLO_DIALECT_VHLO_OPS
+#define STABLEHLO_DIALECT_VHLO_OPS
+
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Dialect and Ops
+//===----------------------------------------------------------------------===//
+
+def VHLO_Dialect : Dialect {
+  let name = "vhlo";
+  let cppNamespace = "::mlir::vhlo";
+
+  let description = [{
+    A shim opset of Versioned StableHLO ops for versions 0.x.x and 1.x.x.
+
+    Version log:
+      0.3.0: Bootstrap from MHLO: https://github.com/openxla/stablehlo/pull/1.
+      0.4.0: Add AllGatherOp::use_global_device_ids: https://github.com/openxla/stablehlo/pull/272.
+ Add CollectivePermuteOp::channel_handle: https://github.com/openxla/stablehlo/pull/388. + Add CustomCallOp::output_operand_aliases: https://github.com/openxla/stablehlo/pull/403. + }]; + + let useDefaultAttributePrinterParser = 0; + let useDefaultTypePrinterParser = 0; +} + +include "stablehlo/dialect/VhloEnums.td" +include "stablehlo/dialect/VhloAttrs.td" + +// Most ops should not use traits. Exceptions are: +// - ReturnOp needs a trait for Terminator. +// - ReduceOp/ReduceWindowOp/ScatterOp need a trait since they have +// multiple variadic arguments. +class VHLO_Op traits = []> : + Op] # traits> { + let extraClassDefinition = [{ + mlir::vhlo::Version $cppClass::getMinVersion() { + return *mlir::vhlo::Version::fromString("}] # minVersion # [{"); + } + mlir::vhlo::Version $cppClass::getMaxVersion() { + if (!strcmp("}] # maxVersion # [{", "current")) return VhloDialect::getCurrentVersion(); + return *mlir::vhlo::Version::fromString("}] # maxVersion # [{"); + } + }]; +} + +def VHLO_ConstantOpV1 : VHLO_Op<"constant"> { + let arguments = (ins ElementsAttr:$value); + let results = (outs VHLO_StaticShapeTensor:$output); +} + +def VHLO_IotaOpV1 : VHLO_Op<"iota"> { + let arguments = (ins I64Attr:$iota_dimension); + let results = (outs VHLO_IntFpOrComplexTensor:$output); +} + +def VHLO_DynamicIotaOpV1 : VHLO_Op<"dynamic_iota"> { + let arguments = (ins VHLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_CreateTokenOpV1 : VHLO_Op<"create_token"> { + let results = (outs VHLO_Token:$output); +} + +def VHLO_AbsOpV1 : VHLO_Op<"abs"> { + let arguments = (ins TensorOf<[VHLO_SInt, VHLO_Float, VHLO_Complex]>:$operand); + let results = (outs TensorOf<[VHLO_SInt, VHLO_Float, VHLO_Complex]>:$result); +} + +def VHLO_CbrtOpV1 : VHLO_Op<"cbrt"> { + let arguments = (ins VHLO_FpTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_CeilOpV1 : VHLO_Op<"ceil"> { + let arguments = (ins 
VHLO_FpTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_ConvertOpV1 : VHLO_Op<"convert"> { + let arguments = (ins VHLO_Tensor:$operand); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_ClzOpV1 : VHLO_Op<"count_leading_zeros"> { + let arguments = (ins VHLO_IntTensor:$operand); + let results = (outs VHLO_IntTensor:$result); +} + +def VHLO_CosineOpV1 : VHLO_Op<"cosine"> { + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_ExpOpV1 : VHLO_Op<"exponential"> { + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_Expm1OpV1 : VHLO_Op<"exponential_minus_one"> { + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} +def VHLO_FloorOpV1 : VHLO_Op<"floor"> { + let arguments = (ins VHLO_FpTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_ImagOpV1 : VHLO_Op<"imag"> { + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_IsFiniteOpV1 : VHLO_Op<"is_finite"> { + let arguments = (ins VHLO_FpTensor:$x); + let results = (outs VHLO_PredTensor:$y); +} + +def VHLO_LogOpV1 : VHLO_Op<"log">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_Log1pOpV1 : VHLO_Op<"log_plus_one">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_LogisticOpV1 : VHLO_Op<"logistic">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_NotOpV1 : VHLO_Op<"not">{ + let arguments = (ins VHLO_PredOrIntTensor:$operand); + let results = (outs VHLO_PredOrIntTensor:$result); +} + +def VHLO_NegOpV1 : VHLO_Op<"negate">{ + let arguments = (ins 
VHLO_IntFpOrComplexTensor:$operand); + let results = (outs VHLO_IntFpOrComplexTensor:$result); +} + +def VHLO_PopulationCountOpV1 : VHLO_Op<"popcnt">{ + let arguments = (ins VHLO_IntTensor:$operand); + let results = (outs VHLO_IntTensor:$result); +} + +def VHLO_RealOpV1 : VHLO_Op<"real">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_RoundOpV1 : VHLO_Op<"round_nearest_afz">{ + let arguments = (ins VHLO_FpTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_RoundNearestEvenOpV1 : VHLO_Op<"round_nearest_even">{ + let arguments = (ins VHLO_FpTensor:$operand); + let results = (outs VHLO_FpTensor:$result); +} + +def VHLO_RsqrtOpV1 : VHLO_Op<"rsqrt">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_SignOpV1 : VHLO_Op<"sign">{ + let arguments = (ins TensorOf<[VHLO_SInt, VHLO_Float, VHLO_Complex]>:$operand); + let results = (outs TensorOf<[VHLO_SInt, VHLO_Float, VHLO_Complex]>:$result); +} + +def VHLO_SineOpV1 : VHLO_Op<"sine">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_SqrtOpV1 : VHLO_Op<"sqrt">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_TanhOpV1 : VHLO_Op<"tanh">{ + let arguments = (ins VHLO_FpOrComplexTensor:$operand); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +// Binary Ops +def VHLO_AddOpV1 : VHLO_Op<"add"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_Atan2OpV1 : VHLO_Op<"atan2"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_ComplexOpV1 : VHLO_Op<"complex"> { + let arguments = (ins VHLO_Fp32Or64Tensor:$lhs, VHLO_Fp32Or64Tensor:$rhs); + let results = (outs 
VHLO_ComplexTensor:$result); +} +def VHLO_DivOpV1 : VHLO_Op<"divide"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_MaxOpV1 : VHLO_Op<"maximum"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_MinOpV1 : VHLO_Op<"minimum"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_MulOpV1 : VHLO_Op<"multiply"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_PowOpV1 : VHLO_Op<"power"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_RemOpV1 : VHLO_Op<"remainder"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_ShiftLeftOpV1 : VHLO_Op<"shift_left"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_ShiftRightArithmeticOpV1 : VHLO_Op<"shift_right_arithmetic"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_ShiftRightLogicalOpV1 : VHLO_Op<"shift_right_logical"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_SubtractOpV1 : VHLO_Op<"subtract"> { + let arguments = (ins VHLO_Tensor:$lhs, VHLO_Tensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} + +// Logical Ops +def VHLO_AndOpV1 : VHLO_Op<"and"> { + let arguments = (ins VHLO_PredOrIntTensor:$lhs, VHLO_PredOrIntTensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_OrOpV1 : VHLO_Op<"or"> { + let arguments = (ins VHLO_PredOrIntTensor:$lhs, VHLO_PredOrIntTensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} +def VHLO_XorOpV1 : VHLO_Op<"xor"> { + let arguments = (ins 
VHLO_PredOrIntTensor:$lhs, VHLO_PredOrIntTensor:$rhs); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_ReturnOpV1 : VHLO_Op<"return", "0.3.0", "current", [Terminator]> { + let arguments = (ins + Variadic:$results + ); + let assemblyFormat = "$results attr-dict (`:` type($results)^)?"; +} + +// Communication op definitions. +def VHLO_InfeedOpV1 : VHLO_Op<"infeed"> { + let arguments = (ins + VHLO_Token:$token, + DefaultValuedStrAttr:$infeed_config, + OptionalAttr:$layout + ); + let results = (outs Variadic); +} + +def VHLO_OutfeedOpV1 : VHLO_Op<"outfeed"> { + let arguments = (ins + Variadic:$inputs, + VHLO_Token:$token, + DefaultValuedStrAttr:$outfeed_config + ); + let results = (outs VHLO_Token); +} + +def VHLO_SendOpV1 : VHLO_Op<"send"> { + let arguments = (ins + Variadic:$inputs, + VHLO_Token:$token, + VHLO_ChannelHandle:$channel_handle, + DefaultValuedOptionalAttr:$is_host_transfer + ); + let results = (outs VHLO_Token); +} + +def VHLO_RecvOpV1 : VHLO_Op<"recv"> { + let arguments = (ins + VHLO_Token:$token, + VHLO_ChannelHandle:$channel_handle, + DefaultValuedOptionalAttr:$is_host_transfer + ); + let results = (outs Variadic); +} + +// Parallelism related op definitions. +def VHLO_ReplicaIdOpV1 : VHLO_Op<"replica_id"> { + let results = (outs TensorOf<[UI32]>); +} + +// Control flow op definitions. 
+def VHLO_AfterAllOpV1 : VHLO_Op<"after_all"> { + let arguments = (ins Variadic:$inputs); + let results = (outs VHLO_Token:$result); +} + +def VHLO_IfOpV1 : VHLO_Op<"if"> { + let arguments = (ins VHLO_PredTensor:$pred); + let regions = (region SizedRegion<1>:$true_branch, + SizedRegion<1>:$false_branch); + let results = (outs Variadic); +} + +def VHLO_CaseOpV1 : VHLO_Op<"case"> { + let arguments = (ins I32Tensor:$index); + let regions = (region VariadicRegion>:$branches); + let results = (outs Variadic); +} + +def VHLO_WhileOpV1 : VHLO_Op<"while"> { + let arguments = (ins Variadic:$operand); + let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); + let results = (outs Variadic); +} + +def VHLO_AllGatherOpV1 : VHLO_Op<"all_gather", "0.3.0", "0.3.0"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64Attr:$all_gather_dim, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_AllGatherOpV2 : VHLO_Op<"all_gather_v2", "0.4.0"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64Attr:$all_gather_dim, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle, + UnitAttr:$use_global_device_ids + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_AllReduceOpV1 : VHLO_Op<"all_reduce"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle, + UnitAttr:$use_global_device_ids + ); + let regions = (region SizedRegion<1>:$computation); + let results = (outs VHLO_Tensor); +} + +def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64Attr:$scatter_dimension, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle, + UnitAttr:$use_global_device_ids + ); + let regions = (region SizedRegion<1>:$computation); + let results = (outs VHLO_Tensor); +} + +def VHLO_AllToAllOpV1 : VHLO_Op<"all_to_all"> { + let arguments = (ins + VHLO_Tensor:$operand, + 
I64Attr:$split_dimension, + I64Attr:$concat_dimension, + I64Attr:$split_count, + I64ElementsAttr:$replica_groups + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_ReduceOpV1 : VHLO_Op<"reduce", "0.3.0", "current", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$inputs, + Variadic:$init_values, + I64ElementsAttr:$dimensions + ); + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$body); +} + +//===----------------------------------------------------------------------===// +// VHLO tuple op definitions. +//===----------------------------------------------------------------------===// +def VHLO_GetTupleElementOpV1 : VHLO_Op<"get_tuple_element"> { + let arguments = (ins + VHLO_Tuple:$operand, + I32Attr:$index + ); + let results = (outs VHLO_TensorOrTokenOrTuple); +} + +def VHLO_TupleOpV1 : VHLO_Op<"tuple"> { + let arguments = (ins Variadic:$val); + let results = (outs VHLO_Tuple:$result); +} + +def VHLO_CompareOpV1 : VHLO_Op<"compare"> { + let arguments = (ins + VHLO_Tensor:$lhs, + VHLO_Tensor:$rhs, + VHLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type + ); + let results = (outs VHLO_PredTensor); +} + +// Slice ops +def VHLO_SliceOpV1 : VHLO_Op<"slice"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$start_indices, + I64ElementsAttr:$limit_indices, + I64ElementsAttr:$strides + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_DynamicSliceOpV1 : VHLO_Op<"dynamic_slice"> { + let arguments = (ins + VHLO_Tensor:$operand, + Variadic:$start_indices, + I64ElementsAttr:$slice_sizes + ); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_DynamicUpdateSliceOpV1 : VHLO_Op<"dynamic_update_slice"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_Tensor:$update, + Variadic:$start_indices + ); + let results = (outs VHLO_Tensor:$result); +} + +// Other op definitions. 
+def VHLO_BatchNormGradOpV1 : VHLO_Op<"batch_norm_grad"> { + let arguments = (ins + RankedTensorOf<[VHLO_Float]>:$operand, + 1DTensorOf<[VHLO_Float]>:$scale, + 1DTensorOf<[VHLO_Float]>:$mean, + 1DTensorOf<[VHLO_Float]>:$variance, + RankedTensorOf<[VHLO_Float]>:$grad_output, + F32Attr:$epsilon, + I64Attr:$feature_index + ); + let results = (outs + RankedTensorOf<[VHLO_Float]>:$grad_operand, + 1DTensorOf<[VHLO_Float]>:$grad_scale, + 1DTensorOf<[VHLO_Float]>:$grad_offset); +} + +def VHLO_BatchNormInferenceOpV1 : VHLO_Op<"batch_norm_inference"> { + let arguments = (ins + RankedTensorOf<[VHLO_Float]>:$operand, + 1DTensorOf<[VHLO_Float]>:$scale, + 1DTensorOf<[VHLO_Float]>:$offset, + 1DTensorOf<[VHLO_Float]>:$mean, + 1DTensorOf<[VHLO_Float]>:$variance, + F32Attr:$epsilon, + I64Attr:$feature_index + ); + let results = (outs RankedTensorOf<[VHLO_Float]>:$result); +} + +def VHLO_BatchNormTrainingOpV1 : VHLO_Op<"batch_norm_training"> { + let arguments = (ins + RankedTensorOf<[VHLO_Float]>:$operand, + 1DTensorOf<[VHLO_Float]>:$scale, + 1DTensorOf<[VHLO_Float]>:$offset, + F32Attr:$epsilon, + I64Attr:$feature_index + ); + let results = (outs + RankedTensorOf<[VHLO_Float]>:$output, + 1DTensorOf<[VHLO_Float]>:$batch_mean, + 1DTensorOf<[VHLO_Float]>:$batch_var); +} + +def VHLO_BitcastConvertOpV1 : VHLO_Op<"bitcast_convert"> { + let arguments = (ins VHLO_Tensor:$operand); + let results = (outs VHLO_Tensor); +} + +def VHLO_BroadcastOpV1 : VHLO_Op<"broadcast"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$broadcast_sizes + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_BroadcastInDimOpV1 : VHLO_Op<"broadcast_in_dim"> { + let arguments = (ins + VHLO_Tensor:$operand, + BroadcastDimAttr:$broadcast_dimensions + ); + let results = (outs VHLO_StaticShapeTensor); +} + +def VHLO_DynamicBroadcastInDimOpV1 : VHLO_Op<"dynamic_broadcast_in_dim"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_DimensionTensor:$output_dimensions, + 
BroadcastDimAttr:$broadcast_dimensions, + OptionalAttr:$known_expanding_dimensions, + OptionalAttr:$known_nonexpanding_dimensions + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_CholeskyOpV1 : VHLO_Op<"cholesky"> { + let arguments = (ins + VHLO_FpOrComplexTensor:$a, + DefaultValuedOptionalAttr:$lower + ); + let results = (outs VHLO_FpOrComplexTensor:$result); +} + +def VHLO_ClampOpV1 : VHLO_Op<"clamp"> { + let arguments = (ins + VHLO_Tensor:$min, + VHLO_Tensor:$operand, + VHLO_Tensor:$max + ); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_ConcatenateOpV1 : VHLO_Op<"concatenate"> { + let arguments = (ins + Variadic:$inputs, + I64Attr:$dimension + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_CollectivePermuteOpV1 : VHLO_Op<"collective_permute", "0.3.0", "0.3.0"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$source_target_pairs + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_CollectivePermuteOpV2 : VHLO_Op<"collective_permute_v2", "0.4.0"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$source_target_pairs, + OptionalAttr:$channel_handle + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_ConvolutionOpV1 : VHLO_Op<"convolution"> { + let arguments = !con( + (ins + VHLO_Tensor:$lhs, + VHLO_Tensor:$rhs), + VHLO_ConvolutionAttributes.attributes); + let results = (outs VHLO_Tensor); +} + +def VHLO_CrossReplicaSumOpV1 : VHLO_Op<"cross-replica-sum"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$replica_groups + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_CustomCallOpV1 : VHLO_Op<"custom_call", "0.3.0", "0.3.0"> { + let arguments = (ins + Variadic:$inputs, + StrAttr:$call_target_name, + DefaultValuedOptionalAttr:$has_side_effect, + DefaultValuedStrAttr:$backend_config, + DefaultValuedOptionalAttr< + VHLO_CustomCallApiVersionAttr, + "::mlir::vhlo::CustomCallApiVersion::API_VERSION_ORIGINAL">: + $api_version, + DefaultValuedOptionalAttr:$called_computations, + 
OptionalAttr:$operand_layouts, + OptionalAttr:$result_layouts + ); + let results = (outs Variadic); +} + +def VHLO_CustomCallOpV2: VHLO_Op<"custom_call_v2", "0.4.0"> { + let arguments = (ins + Variadic:$inputs, + StrAttr:$call_target_name, + DefaultValuedOptionalAttr:$has_side_effect, + DefaultValuedStrAttr:$backend_config, + DefaultValuedOptionalAttr< + VHLO_CustomCallApiVersionAttr, + "::mlir::vhlo::CustomCallApiVersion::API_VERSION_ORIGINAL">: + $api_version, + DefaultValuedOptionalAttr:$called_computations, + OptionalAttr:$operand_layouts, + OptionalAttr:$result_layouts, + DefaultValuedOptionalAttr< + TypedArrayAttrBase< + VHLO_OutputOperandAlias, + "Aliasing attribute for outputs and operands of CustomCall">, + "{}">:$output_operand_aliases + ); + let results = (outs Variadic); +} + +def VHLO_DotOpV1 : VHLO_Op<"dot"> { + let arguments = ( + ins VHLO_Tensor:$lhs, + VHLO_Tensor:$rhs, + VHLO_PrecisionConfigAttr:$precision_config + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_DotGeneralOpV1 : VHLO_Op<"dot_general"> { + let arguments = (ins + VHLO_Tensor:$lhs, + VHLO_Tensor:$rhs, + VHLO_DotDimensionNumbers:$dot_dimension_numbers, + VHLO_PrecisionConfigAttr:$precision_config + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_EinsumOpV1 : VHLO_Op<"einsum"> { + let arguments = (ins + VHLO_Tensor:$lhs, + VHLO_Tensor:$rhs, + StrAttr:$einsum_config + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_UnaryEinsumOpV1 : VHLO_Op<"unary_einsum"> { + let arguments = (ins + VHLO_Tensor:$operand, + StrAttr:$einsum_config + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_FftOpV1 : VHLO_Op<"fft"> { + let arguments = (ins + VHLO_FpOrComplexTensor:$operand, + VHLO_FftTypeAttr:$fft_type, + I64ElementsAttr:$fft_length + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_GatherOpV1 : VHLO_Op<"gather"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_IntTensor:$start_indices, + VHLO_GatherDimensionNumbers:$dimension_numbers, + 
I64ElementsAttr:$slice_sizes, + DefaultValuedOptionalAttr:$indices_are_sorted + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_GetDimensionSizeOpV1 : VHLO_Op<"get_dimension_size"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64Attr:$dimension + ); + let results = (outs I32Tensor); +} + +def VHLO_MapOpV1 : VHLO_Op<"map"> { + let arguments = (ins + Variadic:$inputs, + I64ElementsAttr:$dimensions + ); + let regions = (region SizedRegion<1>:$computation); + let results = (outs VHLO_Tensor); +} + +def VHLO_ReshapeOpV1 : VHLO_Op<"reshape"> { + let arguments = (ins VHLO_Tensor:$operand); + let results = (outs VHLO_StaticShapeTensor); +} + +def VHLO_DynamicReshapeOpV1 : VHLO_Op<"dynamic_reshape"> { + let arguments = (ins VHLO_Tensor:$operand, VHLO_DimensionTensor:$output_shape); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_ScatterOpV1 : VHLO_Op<"scatter", "0.3.0", "current", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$inputs, + TensorOf<[AnyInteger, Index]>:$scatter_indices, + Variadic:$updates, + VHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, + DefaultValuedOptionalAttr:$indices_are_sorted, + DefaultValuedOptionalAttr:$unique_indices + ); + let regions = (region SizedRegion<1>:$update_computation); + let results = (outs Variadic); +} + +def VHLO_SelectOpV1 : VHLO_Op<"select"> { + let arguments = (ins + VHLO_PredTensor:$pred, + VHLO_Tensor:$on_true, + VHLO_Tensor:$on_false + ); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_SelectAndScatterOpV1 : VHLO_Op<"select_and_scatter"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_Tensor:$source, + VHLO_Tensor:$init_value, + OptionalAttr:$window_dimensions, + OptionalAttr:$window_strides, + OptionalAttr:$padding + ); + let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); + let results = (outs VHLO_Tensor); +} + +def VHLO_SetDimensionSizeOpV1 : VHLO_Op<"set_dimension_size"> { + let arguments = (ins + VHLO_Tensor:$operand, + 
I32Tensor:$size, + I64Attr:$dimension + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_SortOpV1 : VHLO_Op<"sort"> { + let arguments = (ins + Variadic:$inputs, + DefaultValuedOptionalAttr:$dimension, + DefaultValuedOptionalAttr:$is_stable + ); + let regions = (region SizedRegion<1>:$comparator); + let results = (outs Variadic); +} + +def VHLO_ReverseOpV1 : VHLO_Op<"reverse"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$dimensions + ); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_PadOpV1 : VHLO_Op<"pad"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_Tensor:$padding_value, + I64ElementsAttr:$edge_padding_low, + I64ElementsAttr:$edge_padding_high, + I64ElementsAttr:$interior_padding + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_TraceOpV1 : VHLO_Op<"trace"> { + let arguments = (ins + VHLO_Tensor:$operand, + StrAttr:$tag + ); +} + +def VHLO_TransposeOpV1 : VHLO_Op<"transpose"> { + let arguments = (ins + VHLO_Tensor:$operand, + I64ElementsAttr:$permutation + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_TriangularSolveOpV1 : VHLO_Op<"triangular_solve"> { + let arguments = (ins + VHLO_FpOrComplexTensor:$a, + VHLO_FpOrComplexTensor:$b, + BoolAttr:$left_side, + BoolAttr:$lower, + BoolAttr:$unit_diagonal, + VHLO_TransposeAttr:$transpose_a + ); + let results = (outs VHLO_FpOrComplexTensor); +} + +def VHLO_ReduceWindowOpV1 : VHLO_Op<"reduce_window", "0.3.0", "current", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$inputs, + Variadic:$init_values, + I64ElementsAttr:$window_dimensions, + // If strides or dilations attributes are missing then the default value is + // one for each of the operand dimensions. Similarly, padding values are zero + // for both low and high in each of the dimensions, if not specified. 
+ OptionalAttr:$window_strides, + OptionalAttr:$base_dilations, + OptionalAttr:$window_dilations, + OptionalAttr:$padding + ); + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$body); +} + +def VHLO_TorchIndexSelectOpV1 : VHLO_Op<"torch_index_select"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_Tensor:$index, + I64Attr:$dim, + I64Attr:$batch_dims + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_OptimizationBarrierOpV1 : VHLO_Op<"optimization_barrier"> { + let arguments = (ins Variadic:$operand); + let results = (outs Variadic:$result); +} + +//===----------------------------------------------------------------------===// +// VHLO RNG Operators. +//===----------------------------------------------------------------------===// + +def VHLO_RngOpV1 : VHLO_Op<"rng"> { + let arguments = (ins + 0DTensorOf<[VHLO_Pred, VHLO_Int, VHLO_Float]>:$a, + 0DTensorOf<[VHLO_Pred, VHLO_Int, VHLO_Float]>:$b, + VHLO_DimensionTensor:$shape, + VHLO_RngDistributionAttr:$rng_distribution + ); + let results = (outs VHLO_PredIntOrFpTensor:$result); +} + +def VHLO_RngBitGeneratorOpV1 : VHLO_Op<"rng_bit_generator"> { + let arguments = (ins + VHLO_RngAlgorithmAttr:$rng_algorithm, + VHLO_IntOrFpTensor:$initial_state + ); + let results = (outs + VHLO_IntOrFpTensor:$output_state, + VHLO_IntOrFpTensor:$output + ); +} + +// Quantize Ops +def VHLO_UniformQuantizeOpV1 : VHLO_Op<"uniform_quantize"> { + let arguments = (ins TensorOf<[F32, BF16, VHLO_QuantizedInt]>:$operand); + let results = (outs VHLO_QuantizedIntTensor:$result); +} + +def VHLO_UniformDequantizeOpV1 : VHLO_Op<"uniform_dequantize"> { + let arguments = (ins VHLO_QuantizedIntTensor:$operand); + let results = (outs TensorOf<[F32, BF16]>:$result); +} + +def VHLO_ReducePrecisionOpV1 : VHLO_Op<"reduce_precision"> { + let arguments = (ins + VHLO_FpTensor:$operand, + I32Attr:$exponent_bits, + I32Attr:$mantissa_bits + ); + let results = (outs VHLO_FpTensor:$output); +} + +def VHLO_RealDynamicSliceOpV1 : 
VHLO_Op<"real_dynamic_slice"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_DimensionTensor:$start_indices, + VHLO_DimensionTensor:$limit_indices, + VHLO_DimensionTensor:$strides + ); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_DynamicPadOpV1 : VHLO_Op<"dynamic_pad"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_Tensor:$padding_value, + VHLO_DimensionTensor:$edge_padding_low, + VHLO_DimensionTensor:$edge_padding_high, + VHLO_DimensionTensor:$interior_padding + ); + let results = (outs VHLO_Tensor:$result); +} + +def VHLO_DynamicGatherOpV1 : VHLO_Op<"dynamic_gather"> { + let arguments = (ins + VHLO_Tensor:$operand, + VHLO_IntTensor:$start_indices, + VHLO_IntTensor:$slice_sizes, + VHLO_GatherDimensionNumbers:$dimension_numbers, + DefaultValuedOptionalAttr:$indices_are_sorted + ); + let results = (outs VHLO_Tensor); +} + +def VHLO_DynamicConvOpV1 : VHLO_Op<"dynamic_conv"> { + let arguments = !con( + (ins + VHLO_Tensor:$lhs, + VHLO_Tensor:$rhs, + VHLO_Tensor:$d_padding), + VHLO_ConvolutionAttributes.attributes); + let results = (outs VHLO_Tensor); +} + +def VHLO_ComputeReshapeShapeOpV1 : VHLO_Op<"compute_reshape_shape"> { + let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape); + let results = (outs 1DTensorOf<[AnyInteger, Index]>:$result); +} + +def VHLO_CstrReshapableOpV1 : VHLO_Op<"cstr_reshapable"> { + let results = (outs Shape_WitnessType:$result); + let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape); +} + +#endif // STABLEHLO_DIALECT_STABLEHLO_OPS diff --git a/stablehlo/tests/legalize_stablehlo_to_vhlo.mlir b/stablehlo/tests/legalize_stablehlo_to_vhlo.mlir new file mode 100644 index 0000000000..449c58c9fe --- /dev/null +++ b/stablehlo/tests/legalize_stablehlo_to_vhlo.mlir @@ -0,0 +1,1770 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --mlir-print-op-generic --split-input-file %s | FileCheck %s +// RUN: diff <(stablehlo-opt 
--stablehlo-legalize-to-vhlo --vhlo-legalize-to-stablehlo stablehlo/tests/legalize_stablehlo_to_vhlo.mlir) <(stablehlo-opt stablehlo/tests/legalize_stablehlo_to_vhlo.mlir) + +// ============ ATTRIBUTES ============ + +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_eq" + +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_ne" + +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_ge" + +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_gt" + +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_le" + +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } 
: (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_lt" + +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_notype" + +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_float" + +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_totalorder" + +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_signed" + +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_unsigned" + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. 
+ +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = 0 : i32 + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" + +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = 1 : i32 + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_original" + +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = 2 : i32 + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" + +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = 3 : i32 + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" + +// DotDimensionNumbers aka #stablehlo.dot is covered below. 
+ +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} +// CHECK-LABEL: "attr_fft_type_fft" + +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} +// CHECK-LABEL: "attr_fft_type_ifft" + +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} +// CHECK-LABEL: "attr_fft_type_rfft" + +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "attr_fft_type_irfft" + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. 
+ +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = [#vhlo] + precision_config = [#stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_config_default" + +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = [#vhlo] + precision_config = [#stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_config_high" + +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = [#vhlo] + precision_config = [#stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_config_highest" + +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} +// CHECK-LABEL: "attr_rng_algorithm_default" + +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} +// CHECK-LABEL: "attr_rng_algorithm_three_fry" + +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, 
%0#1 : tensor, tensor +} +// CHECK-LABEL: "attr_rng_algorithm_philox" + +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_rng_distribution_uniform" + +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_rng_distribution_normal" + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. + +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "attr_transpose_no_transpose" + +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "attr_transpose_transpose" + +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> 
tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "attr_transpose_adjoint" + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. + +func.func @attr_type_extensions_bounds( + %arg0: tensor>) + -> tensor> { + // CHECK: "func.return"(%arg0) : (tensor>) -> () + func.return %arg0 : tensor> +} + +// CHECK-LABEL: "attr_type_extensions_bounds" + +// ============ OPS ============ + +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_abs" + +func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_add" + +func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.after_all"(%arg0) : (!vhlo.token) -> !vhlo.token + %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} +// CHECK-LABEL: "op_after_all" + +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%arg0) { + // CHECK-SAME: all_gather_dim = 1 : i64, + // CHECK-SAME: channel_handle = #vhlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + // CHECK-SAME: use_global_device_ids + // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<16x16xf32> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "op_all_gather" + +func.func @op_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce"(%arg0) ({ + // CHECK-NEXT: 
^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add"(%[[ARG1]], %[[ARG2]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: channel_handle = #vhlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + // CHECK-SAME: use_global_device_ids + // CHECK-SAME: } : (tensor) -> tensor + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_all_reduce" + +func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all"(%arg0) { + // CHECK-SAME: concat_dimension = 0 : i64, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + // CHECK-SAME: split_count = 4 : i64, + // CHECK-SAME: split_dimension = 1 : i64 + // CHECK-SAME: } : (tensor<4x16xf32>) -> tensor<16x4xf32> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} +// CHECK-LABEL: "op_all_to_all" + +func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_and" + +func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + 
func.return %0 : tensor +} +// CHECK-LABEL: "op_atan2" + +func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + // CHECK-SAME: epsilon = 1.000000e-03 : f32, + // CHECK-SAME: feature_index = 0 : i64 + // CHECK-SAME: } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} +// CHECK-LABEL: "op_batch_norm_grad" + +func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + // CHECK: "vhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + // CHECK-SAME: epsilon = 1.000000e-03 : f32, + // CHECK-SAME: feature_index = 0 : i64 + // CHECK-SAME: } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} +// CHECK-LABEL: "op_batch_norm_inference" + +func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> 
(tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_training"(%arg0, %arg1, %arg2) { + // CHECK-SAME: epsilon = 1.000000e-03 : f32, + // CHECK-SAME: feature_index = 0 : i64 + // CHECK-SAME: } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} +// CHECK-LABEL: "op_batch_norm_training" + +func.func @op_bitcast_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.bitcast_convert"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_bitcast_convert" + +func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_in_dim"(%arg0) { + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64> + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16x16xf32> + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<1> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "op_broadcast_in_dim" + +func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast"(%arg0) { + // CHECK-SAME: broadcast_sizes = dense<16> : tensor<1xi64> + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16x16xf32> + %0 = "stablehlo.broadcast"(%arg0) { + broadcast_sizes = dense<16> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "op_broadcast" + +func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case"(%arg0) ({ + // CHECK-NEXT: 
"vhlo.return"(%arg1) : (tensor) -> () + // CHECK-NEXT: }) : (tensor) -> tensor + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_case" + +func.func @op_cbrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cbrt"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_cbrt" + +func.func @op_ceil(%arg0: tensor) -> tensor { + // CHECK: "vhlo.ceil"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_ceil" + +func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky"(%arg0) { + // CHECK-SAME: lower = true + // CHECK-SAME: } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + %0 = "stablehlo.cholesky"(%arg0) { + lower = true + } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} +// CHECK-LABEL: "op_cholesky" + +func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_clamp" + +func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { + // CHECK: "vhlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_count_leading_zeros" + +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v2"(%arg0) { + // CHECK-SAME: channel_handle = #vhlo.channel_handle, + // CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<16x8xf32> + %0 = 
"stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} +// CHECK-LABEL: "op_collective_permute" + +func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare"(%arg0, %arg1) { + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: } : (tensor, tensor) -> tensor + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_compare" + +func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK: "vhlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + func.return %0 : tensor> +} +// CHECK-LABEL: "op_complex" + +func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { + // CHECK: "vhlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> + %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> + func.return %0 : tensor<1xindex> +} +// CHECK-LABEL: "op_compute_reshape_shape" + +func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.concatenate"(%arg0, %arg1) { + // CHECK-SAME: dimension = 0 : i64 + // CHECK-SAME: } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + %0 = "stablehlo.concatenate"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_concatenate" + +func.func @op_constant(%arg0: tensor) -> tensor { + // CHECK: "vhlo.constant"() { + // CHECK-SAME: value = dense<0.000000e+00> : tensor + // CHECK-SAME: } : () -> tensor + %0 
= "stablehlo.constant"() { + value = dense<0.0> : tensor + } : () -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_constant" + +func.func @op_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.convert"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_convert" + +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // CHECK: "vhlo.convolution"(%arg0, %arg1) { + // CHECK-SAME: batch_group_count = 1 : i64, + // CHECK-SAME: dimension_numbers = #vhlo.conv, + // CHECK-SAME: feature_group_count = 1 : i64, + // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, + // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, + // CHECK-SAME: precision_config = [#vhlo, #vhlo], + // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, + // CHECK-SAME: window_reversal = dense : tensor<2xi1>, + // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> + // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<1> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<1> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> +} +// CHECK-LABEL: "op_convolution" + +func.func @op_cosine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cosine"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_cosine" + +func.func @op_create_token() -> !stablehlo.token { + // CHECK: 
"vhlo.create_token"() : () -> !vhlo.token + %0 = "stablehlo.create_token"() : () -> !stablehlo.token + func.return %0 : !stablehlo.token +} +// CHECK-LABEL: "op_create_token" + +func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cross-replica-sum"(%arg0) { + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + // CHECK-SAME: } : (tensor) -> tensor + %0 = "stablehlo.cross-replica-sum"(%arg0) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_cross_replica_sum" + +func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { + // CHECK: "vhlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness + %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness + func.return %0 : !shape.witness +} +// CHECK-LABEL: "op_cstr_reshapable" + +func.func @called_computation() { func.return } + +func.func @op_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v2"(%arg0) { + // CHECK-SAME: api_version = 1 : i32, + // CHECK-SAME: backend_config = "", + // CHECK-SAME: call_target_name = "foo", + // CHECK-SAME: called_computations = [@foo], + // CHECK-SAME: has_side_effect = false, + // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>], + // CHECK-SAME: output_operand_aliases = [ + // CHECK-SAME: #vhlo.output_operand_alias< + // CHECK-SAME: outputTupleIndices = [], + // CHECK-SAME: operandIndex = 0, + // CHECK-SAME: operandTupleIndices = []>] + // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] + // CHECK-SAME: } : (tensor) -> tensor + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = false, + backend_config = "", + api_version = 1 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = 
[dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_custom_call" + +func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_divide" + +func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general"(%arg0, %arg1) { + // CHECK-SAME: dot_dimension_numbers = #vhlo.dot< + // CHECK-SAME: lhsBatchingDimensions = [0], + // CHECK-SAME: rhsBatchingDimensions = [0], + // CHECK-SAME: lhsContractingDimensions = [2], + // CHECK-SAME: rhsContractingDimensions = [1] + // CHECK-SAME: >, + // CHECK-SAME: precision_config = [] + // CHECK-SAME: } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + precision_config = [] + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} +// CHECK-LABEL: "op_dot_general" + +func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot"(%arg0, %arg1) { + // CHECK-SAME: precision_config = [] + // CHECK-SAME: } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + %0 = "stablehlo.dot"(%arg0, %arg1) { + precision_config = [] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "op_dot" + +func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + // CHECK-SAME: broadcast_dimensions = dense<1> : tensor<1xi64>, + // CHECK-SAME: 
known_expanding_dimensions = dense<> : tensor<0xi64>, + // CHECK-SAME: known_nonexpanding_dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: } : (tensor, tensor<2xindex>) -> tensor + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<1> : tensor<1xi64>, + known_expanding_dimensions = dense<[]> : tensor<0xi64>, + known_nonexpanding_dimensions = dense<0> : tensor<1xi64> + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_dynamic_broadcast_in_dim" + +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv"(%arg0, %arg1, %arg2) { + // CHECK-SAME: batch_group_count = 1 : i64, + // CHECK-SAME: dimension_numbers = #vhlo.conv, + // CHECK-SAME: feature_group_count = 1 : i64, + // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, + // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, + // CHECK-SAME: precision_config = [#vhlo, #vhlo], + // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, + // CHECK-SAME: window_reversal = dense : tensor<2xi1>, + // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> + // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + window_strides = dense<1> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<1> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} +// CHECK-LABEL: "op_dynamic_conv" + +func.func @op_dynamic_gather(%arg0 : 
tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather"(%arg0, %arg1, %arg2) { + // CHECK-SAME: dimension_numbers = #vhlo.gather< + // CHECK-SAME: offsetDims = [2], + // CHECK-SAME: collapsedSliceDims = [0, 1], + // CHECK-SAME: startIndexMap = [0, 1], + // CHECK-SAME: indexVectorDim = 2 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false + // CHECK-SAME: } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + indices_are_sorted = false + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} +// CHECK-LABEL: "op_dynamic_gather" + +func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_iota"(%arg0) { + // CHECK-SAME: iota_dimension = 0 : i64 + // CHECK-SAME: } : (tensor<1xindex>) -> tensor + %0 = "stablehlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_dynamic_iota" + +func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_dynamic_pad" + +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : 
(tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_dynamic_reshape" + +func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { + // CHECK: "vhlo.dynamic_slice"(%arg0, %arg1) { + // CHECK-SAME: slice_sizes = dense<4> : tensor<1xi64> + // CHECK-SAME: } : (tensor<16xf32>, tensor) -> tensor<4xf32> + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { + slice_sizes = dense<4> : tensor<1xi64> + } : (tensor<16xf32>, tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} +// CHECK-LABEL: "op_dynamic_slice" + +func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_dynamic_update_slice" + +func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.einsum"(%arg0, %arg1) { + // CHECK-SAME: einsum_config = "ab,bc->ac" + // CHECK-SAME: } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + %0 = "stablehlo.einsum"(%arg0, %arg1) { + einsum_config = "ab,bc->ac" + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "op_einsum" + +func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_exponential_minus_one" + +func.func @op_exponential(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_exponential" + +func.func 
@op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + // CHECK: "vhlo.fft"(%arg0) { + // CHECK-SAME: fft_length = dense<16> : tensor<1xi64>, + // CHECK-SAME: fft_type = #vhlo + // CHECK-SAME: } : (tensor<16xcomplex>) -> tensor<16xcomplex> + %0 = "stablehlo.fft"(%arg0) { + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} +// CHECK-LABEL: "op_fft" + +func.func @op_floor(%arg0: tensor) -> tensor { + // CHECK: "vhlo.floor"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_floor" + +func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather"(%arg0, %arg1) { + // CHECK-SAME: dimension_numbers = #vhlo.gather< + // CHECK-SAME: offsetDims = [2], + // CHECK-SAME: collapsedSliceDims = [0, 1], + // CHECK-SAME: startIndexMap = [0, 1], + // CHECK-SAME: indexVectorDim = 2 + // CHECK-SAME: >, + // CHECK-SAME: indices_are_sorted = false, + // CHECK-SAME: slice_sizes = dense<1> : tensor<3xi64> + // CHECK-SAME: } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<3xi64>, + indices_are_sorted = false + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} +// CHECK-LABEL: "op_gather" + +func.func @op_get_dimension_size(%arg0: tensor) -> tensor { + // CHECK: "vhlo.get_dimension_size"(%arg0) { + // CHECK-SAME: dimension = 0 : i64 + // CHECK-SAME: } : (tensor) -> tensor + %0 = "stablehlo.get_dimension_size"(%arg0) { + dimension = 0 : i64 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_get_dimension_size" + +func.func 
@op_get_tuple_element(%arg0: tuple>) -> tensor { + // CHECK: "vhlo.get_tuple_element"(%arg0) { + // CHECK-SAME: index = 0 : i32 + // CHECK-SAME: } : (tuple>) -> tensor + %0 = "stablehlo.get_tuple_element"(%arg0) { + index = 0 : i32 + } : (tuple>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_get_tuple_element" + +func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.if"(%arg0) ({ + // CHECK-NEXT: "vhlo.return"(%arg1) : (tensor) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: "vhlo.return"(%arg2) : (tensor) -> () + // CHECK-NEXT: }) : (tensor) -> tensor + %0 = "stablehlo.if"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + "stablehlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_if" + +func.func @op_imag(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.imag"(%arg0) : (tensor>) -> tensor + %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_imag" + +func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed"(%arg0) { + // CHECK-SAME: infeed_config = "", + // CHECK-SAME{LITERAL}: layout = [[]] + // CHECK-SAME: } : (!vhlo.token) -> (tensor, !vhlo.token) + %0:2 = "stablehlo.infeed"(%arg0) { + infeed_config = "", + layout = [[]] + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} +// CHECK-LABEL: "op_infeed" + +func.func @op_iota() -> tensor<16xf32> { + // CHECK: "vhlo.iota"() { + // CHECK-SAME: iota_dimension = 0 : i64 + // CHECK-SAME: } : () -> tensor<16xf32> + %0 = "stablehlo.iota"() { + iota_dimension = 0 : i64 + } : () -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_iota" + +func.func @op_is_finite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.is_finite"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// 
CHECK-LABEL: "op_is_finite" + +func.func @op_log(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_log" + +func.func @op_log_plus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_plus_one"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_log_plus_one" + +func.func @op_logistic(%arg0: tensor) -> tensor { + // CHECK: "vhlo.logistic"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_logistic" + +func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.map"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs"(%[[ARG1]]) : (tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + %0 = "stablehlo.map"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_map" + +func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_maximum" + +func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_minimum" + +func.func @op_multiply(%arg0: tensor, %arg1: 
tensor) -> tensor { + // CHECK: "vhlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_multiply" + +func.func @op_negate(%arg0: tensor) -> tensor { + // CHECK: "vhlo.negate"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_negate" + +func.func @op_not(%arg0: tensor) -> tensor { + // CHECK: "vhlo.not"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_not" + +func.func @op_optimization_barrier(%arg0: tensor) -> tensor { + // CHECK: "vhlo.optimization_barrier"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_optimization_barrier" + +func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_or" + +func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed"(%arg0, %arg1) { + // CHECK-SAME: outfeed_config = "" + // CHECK-SAME: } : (tensor, !vhlo.token) -> !vhlo.token + %0 = "stablehlo.outfeed"(%arg0, %arg1) { + outfeed_config = "" + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} +// CHECK-LABEL: "op_outfeed" + +func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.pad"(%arg0, %arg1) { + // CHECK-SAME: edge_padding_high = dense<4> : tensor<1xi64>, + // CHECK-SAME: edge_padding_low = dense<4> : tensor<1xi64>, + // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> + // CHECK-SAME: } : (tensor<8xf32>, tensor) -> tensor<16xf32> + %0 = "stablehlo.pad"(%arg0, %arg1) { + edge_padding_high 
= dense<4> : tensor<1xi64>, + edge_padding_low = dense<4> : tensor<1xi64>, + interior_padding = dense<0> : tensor<1xi64> + } : (tensor<8xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_pad" + +func.func @op_popcnt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.popcnt"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_popcnt" + +func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_power" + +func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_real_dynamic_slice" + +func.func @op_real(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.real"(%arg0) : (tensor>) -> tensor + %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_real" + +func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv"(%arg0) { + // CHECK-SAME: channel_handle = #vhlo.channel_handle, + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: } : (!vhlo.token) -> (tensor, !vhlo.token) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} +// CHECK-LABEL: "op_recv" + +func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + %0 = 
"stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_reduce" + +func.func @op_reduce_precision(%arg0: tensor) -> tensor { + // CHECK: "vhlo.reduce_precision"(%arg0) { + // CHECK-SAME: exponent_bits = 8 : i32, + // CHECK-SAME: mantissa_bits = 10 : i32 + // CHECK-SAME: } : (tensor) -> tensor + %0 = "stablehlo.reduce_precision"(%arg0) { + exponent_bits = 8 : i32, + mantissa_bits = 10 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_reduce_precision" + +func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add"(%[[ARG1]], %[[ARG2]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: channel_handle = #vhlo.channel_handle, + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + // CHECK-SAME: scatter_dimension = 0 : i64 + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + scatter_dimension = 0 : i64 + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_reduce_scatter" + +func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x5x8x7xf32> { + // CHECK: "vhlo.reduce_window"(%arg0, %arg1) ({ + // CHECK-NEXT: 
^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum"(%[[ARG2]], %[[ARG3]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>, + // CHECK-SAME{LITERAL}: padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>, + // CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + // CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64> + // CHECK-SAME: } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x5x8x7xf32> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, + base_dilations = dense<[1, 1, 1, 1]> : tensor<4xi64>, + window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x5x8x7xf32> + func.return %0 : tensor<2x5x8x7xf32> +} +// CHECK-LABEL: "op_reduce_window" + +func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_remainder" + +func.func @op_replica_id() -> tensor { + // CHECK: "vhlo.replica_id"() : () -> tensor + %0 = "stablehlo.replica_id"() : () -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_replica_id" + +func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { + // CHECK: "vhlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> + %0 = "stablehlo.reshape"(%arg0) : 
(tensor<16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} +// CHECK-LABEL: "op_reshape" + +func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case"(%arg0) ({ + // CHECK-NEXT: "vhlo.return"(%arg1) : (tensor) -> () + // CHECK-NEXT: }) : (tensor) -> tensor + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_return" + +func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reverse"(%arg0) { + // CHECK-SAME: dimensions = dense<0> : tensor<1xi64> + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + %0 = "stablehlo.reverse"(%arg0) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_reverse" + +func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { + // CHECK: "vhlo.rng_bit_generator"(%arg0) { + // CHECK-SAME: rng_algorithm = #vhlo + // CHECK-SAME: } : (tensor) -> (tensor, tensor) + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} +// CHECK-LABEL: "op_rng_bit_generator" + +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.rng"(%arg0, %arg1, %arg2) { + // CHECK-SAME: rng_distribution = #vhlo + // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_rng" + +func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_round_nearest_afz" + +func.func @op_round_nearest_even(%arg0: tensor) -> tensor { + // CHECK: 
"vhlo.round_nearest_even"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_round_nearest_even" + +func.func @op_rsqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.rsqrt"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_rsqrt" + +func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter"(%arg0, %arg1, %arg2) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: tensor, %[[ARG4:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add"(%[[ARG3]], %[[ARG4]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: indices_are_sorted = true, + // CHECK-SAME: scatter_dimension_numbers = #vhlo.scatter< + // CHECK-SAME: updateWindowDims = [1], + // CHECK-SAME: insertedWindowDims = [0, 1], + // CHECK-SAME: scatterDimsToOperandDims = [0, 1], + // CHECK-SAME: indexVectorDim = 1 + // CHECK-SAME: >, + // CHECK-SAME: unique_indices = true + // CHECK-SAME: } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} +// CHECK-LABEL: "op_scatter" + +func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: 
tensor<10x12x12x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: tensor, %[[ARG41:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare"(%[[ARG31]], %[[ARG41]]) {compare_type = #vhlo, comparison_direction = #vhlo} : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL11]]) : (tensor) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: tensor, %[[ARG42:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add"(%[[ARG32]], %[[ARG42]]) : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL12]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: padding = dense<0> : tensor<4x2xi64>, + // CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + // CHECK-SAME: window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + // CHECK-SAME: } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} +// CHECK-LABEL: "op_select_and_scatter" + +func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : 
(tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_select" + +func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send"(%arg0, %arg1) { + // CHECK-SAME: channel_handle = #vhlo.channel_handle, + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: } : (tensor, !vhlo.token) -> !vhlo.token + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} +// CHECK-LABEL: "op_send" + +func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.set_dimension_size"(%arg0, %arg1) { + // CHECK-SAME: dimension = 0 : i64 + // CHECK-SAME: } : (tensor, tensor) -> tensor<16xf32> + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_set_dimension_size" + +func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_shift_left" + +func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_shift_right_arithmetic" + +func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_shift_right_logical" + +func.func @op_sign(%arg0: tensor) -> tensor { + 
// CHECK: "vhlo.sign"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_sign" + +func.func @op_sine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sine"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_sine" + +func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { + // CHECK: "vhlo.slice"(%arg0) { + // CHECK-SAME: limit_indices = dense<4> : tensor<1xi64>, + // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<4xf32> + %0 = "stablehlo.slice"(%arg0) { + start_indices = dense<0> : tensor<1xi64>, + limit_indices = dense<4> : tensor<1xi64>, + strides = dense<1> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} +// CHECK-LABEL: "op_slice" + +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor, %[[ARG2:arg.*]]: tensor): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare"(%[[ARG1]], %[[ARG2]]) {compare_type = #vhlo, comparison_direction = #vhlo} : (tensor, tensor) -> tensor + // CHECK-NEXT: "vhlo.return"(%[[VAL1]]) : (tensor) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: dimension = 0 : i64, + // CHECK-SAME: is_stable = true + // CHECK-SAME: } : (tensor<16xf32>) -> tensor<16xf32> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_sort" + +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt"(%arg0) : (tensor) -> tensor + 
%0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_sqrt" + +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_subtract" + +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_tanh" + +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: batch_dims = 0 : i64, + // CHECK-SAME: dim = 0 : i64 + // CHECK-SAME: } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} +// CHECK-LABEL: "op_torch_index_select" + +func.func @op_trace(%arg0: tensor) { + // CHECK: "vhlo.trace"(%arg0) { + // CHECK-SAME: tag = "foo" + // CHECK-SAME: } : (tensor) -> () + "stablehlo.trace"(%arg0) { + tag = "foo" + } : (tensor) -> () + func.return +} +// CHECK-LABEL: "op_trace" + +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose"(%arg0) { + // CHECK-SAME: permutation = dense<[1, 0]> : tensor<2xi64> + // CHECK-SAME: } : (tensor<16x8xf32>) -> tensor<8x16xf32> + %0 = "stablehlo.transpose"(%arg0) { + permutation = dense<[1, 0]> : tensor<2xi64> + } : (tensor<16x8xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} +// CHECK-LABEL: "op_transpose" + +func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.triangular_solve"(%arg0, %arg1) { + // CHECK-SAME: left_side 
= true, + // CHECK-SAME: lower = true, + // CHECK-SAME: transpose_a = #vhlo, + // CHECK-SAME: unit_diagonal = true + // CHECK-SAME: } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "op_triangular_solve" + +func.func @op_tuple(%arg0: tensor) -> tuple> { + // CHECK: "vhlo.tuple"(%arg0) : (tensor) -> tuple> + %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> + func.return %0 : tuple> +} +// CHECK-LABEL: "op_tuple" + +func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { + // CHECK: "vhlo.unary_einsum"(%arg0) { + // CHECK-SAME: einsum_config = "ab->a" + // CHECK-SAME: } : (tensor<8x16xf32>) -> tensor<8xf32> + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} +// CHECK-LABEL: "op_unary_einsum" + +func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_uniform_dequantize" + +func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { + // CHECK: "vhlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + func.return %0 : tensor> +} +// CHECK-LABEL: "op_uniform_quantize" + +func.func @op_while(%arg0: tensor) -> tensor { + // CHECK: "vhlo.while"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor): + // CHECK-NEXT: "vhlo.return"(%[[ARG1]]) : (tensor) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: tensor): + // CHECK-NEXT: "vhlo.return"(%[[ARG1]]) : (tensor) -> () + // CHECK-NEXT: }) : (tensor) -> tensor + %0 = 
"stablehlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} +// CHECK-LABEL: "op_while" + +func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_xor" + +// ============ TYPES ============ + +func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_i1" + +func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_i4" + +func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_i8" + +func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_i16" + +func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_i32" + +func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> 
tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_i64" + +func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_ui4" + +func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_ui8" + +func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_ui16" + +func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_ui32" + +func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_ui64" + +func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_bf16" + +func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f16" + +func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = 
"stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f32" + +func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f64" + +func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} +// CHECK-LABEL: "type_complex_f32" + +func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} +// CHECK-LABEL: "type_complex_f64" + +func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs"(%arg0) : (tensor) -> tensor + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_dynamism_ranked" + +func.func @type_dynamism_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "vhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "stablehlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} +// CHECK-LABEL: "type_dynamism_unranked" + +func.func @type_quantization(%arg0: tensor>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_quantization" + +func.func @type_sparsity(%arg0: tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<16xf32> { + // CHECK: "vhlo.abs"(%arg0) : (tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<16xf32> + %0 = 
"stablehlo.abs"(%arg0) : (tensor<16xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "type_sparsity" + +func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "func.return"(%arg0) : (!vhlo.token) -> () + return %arg0 : !stablehlo.token +} +// CHECK: function_type = (!vhlo.token) -> !vhlo.token +// CHECK-LABEL: "type_token_callee" + +func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "func.call"(%arg0) {callee = @type_token_callee} : (!vhlo.token) -> !vhlo.token + %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token + return %0 : !stablehlo.token +} +// CHECK: function_type = (!vhlo.token) -> !vhlo.token +// CHECK-LABEL: "type_token_caller" + +func.func @type_tuple(%arg0: tuple>) -> tuple { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + // CHECK: (tuple>) -> tuple + } : (tuple>) -> tuple + return %0 : tuple +} +// CHECK-LABEL: "type_tuple" diff --git a/stablehlo/tests/ops_chlo_roundtrip.mlir b/stablehlo/tests/ops_chlo_roundtrip.mlir index 62fa45a305..9d8f7cf335 100644 --- a/stablehlo/tests/ops_chlo_roundtrip.mlir +++ b/stablehlo/tests/ops_chlo_roundtrip.mlir @@ -3,7 +3,6 @@ // RUN: stablehlo-opt -emit-bytecode -debug-only=chlo-bytecode %s 2>&1 | (! grep 'Not Implemented') // RUN: stablehlo-opt -emit-bytecode %s | stablehlo-opt -debug-only=chlo-bytecode 2>&1 | (! 
grep 'Not Implemented') - // CHECK-LABEL: func @chlo_acos( // CHECK-SAME: %[[A:.*]]: tensor<8x8xf64> // CHECK: %[[T:.*]] = chlo.acos %[[A]] : tensor<8x8xf64> -> tensor<8x8xf64> diff --git a/stablehlo/tests/vhlo_to_version_downgrade.mlir b/stablehlo/tests/vhlo_to_version_downgrade.mlir new file mode 100644 index 0000000000..828c569cc1 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_downgrade.mlir @@ -0,0 +1,41 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=0.3.0' %s | FileCheck %s + + +// CHECK-LABEL: @all_gather_to_v1 +func.func @all_gather_to_v1(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK-NEXT: %0 = "vhlo.all_gather"(%arg0) + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: @collective_permute_to_v1 +func.func @collective_permute_to_v1(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK-NEXT: %0 = "vhlo.collective_permute"(%arg0) + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: @custom_call_to_v1 +func.func @custom_call_to_v1(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: %0 = "vhlo.custom_call"(%arg0) + %0 = stablehlo.custom_call @foo(%arg0) : (tensor<2xi1>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK-LABEL: @custom_call_to_v1_empty_output_operand_alias +func.func @custom_call_to_v1_empty_output_operand_alias(%arg0 : tensor) -> tensor { + // CHECK-NEXT: %0 = "vhlo.custom_call"(%arg0) + %0 = stablehlo.custom_call @foo(%arg0) { + has_side_effect = false, + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [], + result_layouts = [dense<> : tensor<0xindex>] + } : 
(tensor) -> tensor + func.return %0 : tensor +} diff --git a/stablehlo/tests/vhlo_to_version_downgrade_invalid.mlir b/stablehlo/tests/vhlo_to_version_downgrade_invalid.mlir new file mode 100644 index 0000000000..178b8b7814 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_downgrade_invalid.mlir @@ -0,0 +1,58 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=0.3.0' --verify-diagnostics --split-input-file %s + +func.func @custom_call_v2_with_output_operand_alises(%arg0 : tensor) -> tensor { + // expected-error @+2 {{failed to downgrade vhlo.custom_call_v2, op has a non-empty output_operand_aliases attribute}} + // expected-error @+1 {{failed to legalize operation 'vhlo.custom_call_v2' that was explicitly marked illegal}} + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = false, + backend_config = "", + api_version = 1 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = [dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // expected-error @+2 {{failed to downgrade vhlo.collective_permute_v2, op has a non-empty channel_handle attribute}} + // expected-error @+1 {{failed to legalize operation 'vhlo.collective_permute_v2' that was explicitly marked illegal}} + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// ----- + +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // expected-error @+2 {{failed to downgrade vhlo.all_gather_v2, op has a non-empty use_global_device_ids attribute}} + // expected-error @+1 {{failed to legalize 
operation 'vhlo.all_gather_v2' that was explicitly marked illegal}} + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// ----- + +// This test emulates two things: +// 1. A file that is too old and no longer supported on consumer. +// 2. A file that is too new and not yet supported on consumer. +// More work should be done to improve this error message. +func.func @invalid_program_unknown_op(%arg0 : tensor) -> (tensor) { + // expected-error @+1 {{unregistered operation 'vhlo.unknown_op' found in dialect ('vhlo') that does not allow unknown operations}} + %0 = "vhlo.unknown_op"(%arg0) : (tensor) -> tensor + func.return +} diff --git a/stablehlo/tests/vhlo_to_version_invalid_target_empty.mlir b/stablehlo/tests/vhlo_to_version_invalid_target_empty.mlir new file mode 100644 index 0000000000..00df7be6fc --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_invalid_target_empty.mlir @@ -0,0 +1,2 @@ +// RUN: stablehlo-opt --vhlo-to-version --verify-diagnostics %s +// expected-error @-2 {{No target version specified. 
Specify target using: --vhlo-to-version='target=[targetVersion]'}} diff --git a/stablehlo/tests/vhlo_to_version_invalid_target_future.mlir b/stablehlo/tests/vhlo_to_version_invalid_target_future.mlir new file mode 100644 index 0000000000..b58f002193 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_invalid_target_future.mlir @@ -0,0 +1,2 @@ +// RUN: stablehlo-opt --vhlo-to-version='target=100.10.10' --verify-diagnostics %s +// expected-error @-2 {{target version 100.10.10 is greater than current version 0.4.0}} diff --git a/stablehlo/tests/vhlo_to_version_invalid_target_minimum.mlir b/stablehlo/tests/vhlo_to_version_invalid_target_minimum.mlir new file mode 100644 index 0000000000..3bd9646436 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_invalid_target_minimum.mlir @@ -0,0 +1,2 @@ +// RUN: stablehlo-opt --vhlo-to-version='target=0.0.0' --verify-diagnostics %s +// expected-error @-2 {{target version 0.0.0 is less than minimum supported 0.3.0}} diff --git a/stablehlo/tests/vhlo_to_version_invalid_target_not_version.mlir b/stablehlo/tests/vhlo_to_version_invalid_target_not_version.mlir new file mode 100644 index 0000000000..6a47021024 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_invalid_target_not_version.mlir @@ -0,0 +1,2 @@ +// RUN: stablehlo-opt --vhlo-to-version='target=x.y.z' --verify-diagnostics %s +// expected-error @-2 {{Invalid target version argument 'x.y.z'}} diff --git a/stablehlo/tests/vhlo_to_version_upgrade.mlir b/stablehlo/tests/vhlo_to_version_upgrade.mlir new file mode 100644 index 0000000000..4f3b08546a --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_upgrade.mlir @@ -0,0 +1,23 @@ +// RUN: stablehlo-opt --vhlo-to-version='target=0.4.0' %s | FileCheck %s +// RUN: stablehlo-opt --vhlo-to-version='target=current' %s | FileCheck %s + +// CHECK-LABEL: @all_gather_to_v2 +func.func @all_gather_to_v2(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK-NEXT: %0 = "vhlo.all_gather_v2"(%arg0) + %0 = "vhlo.all_gather"(%arg0) 
{all_gather_dim = 1 : i64, channel_handle = #vhlo.channel_handle, replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>} : (tensor<16x8xf32>) -> tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: @collective_permute_to_v2 +func.func @collective_permute_to_v2(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK-NEXT: %0 = "vhlo.collective_permute_v2"(%arg0) + %0 = "vhlo.collective_permute"(%arg0) {source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<16x8xf32>) -> tensor<16x8xf32> + return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: @custom_call_to_v2 +func.func @custom_call_to_v2(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: %0 = "vhlo.custom_call_v2"(%arg0) + %0 = "vhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "foo"} : (tensor<2xi1>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} diff --git a/stablehlo/tools/CMakeLists.txt b/stablehlo/tools/CMakeLists.txt index f0fd6e5627..20f4f60e36 100644 --- a/stablehlo/tools/CMakeLists.txt +++ b/stablehlo/tools/CMakeLists.txt @@ -27,6 +27,7 @@ set(LIBS MLIROptLib StablehloRegister StablehloTestUtils + StablehloPasses ) add_llvm_executable(stablehlo-opt StablehloOptMain.cpp) llvm_update_compile_flags(stablehlo-opt) diff --git a/stablehlo/tools/StablehloOptMain.cpp b/stablehlo/tools/StablehloOptMain.cpp index a097a2d26e..e5109262da 100644 --- a/stablehlo/tools/StablehloOptMain.cpp +++ b/stablehlo/tools/StablehloOptMain.cpp @@ -18,10 +18,12 @@ limitations under the License. 
#include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "stablehlo/dialect/Register.h" #include "stablehlo/tests/TestUtils.h" +#include "stablehlo/transforms/Passes.h" int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::hlo::registerAllTestPasses(); + mlir::stablehlo::registerPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt new file mode 100644 index 0000000000..25a1da3723 --- /dev/null +++ b/stablehlo/transforms/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright 2022 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(PassesIncGen) + +add_mlir_dialect_library(StablehloTypeConversion + PARTIAL_SOURCES_INTENDED + TypeConversion.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + MLIRFuncDialect +) + +add_mlir_dialect_library(StablehloPasses + PARTIAL_SOURCES_INTENDED + LegalizeStablehloToVhlo.cpp + LegalizeVhloToStablehlo.cpp + VhloToVersion.cpp + + DEPENDS + PassesIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + MLIRQuantDialect + StablehloTypeConversion +) diff --git a/stablehlo/transforms/LegalizeStablehloToVhlo.cpp b/stablehlo/transforms/LegalizeStablehloToVhlo.cpp new file mode 100644 index 0000000000..b998a7caeb --- /dev/null +++ b/stablehlo/transforms/LegalizeStablehloToVhlo.cpp @@ -0,0 +1,227 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/IR/Attributes.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/VhloOps.h" +#include "stablehlo/transforms/MapStablehloToVhlo.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/TypeConversion.h" + +#define DEBUG_TYPE "compat-passes" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOLEGALIZETOVHLOPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { + +#define RETURN_CONVERTED_ENUM_ATTR(Name) \ + auto stablehloValue = stablehlo::stringify##Name(attr.getValue()); \ + auto vhloValue = vhlo::symbolize##Name(stablehloValue); \ + if (!vhloValue.has_value()) return {}; \ + return vhlo::Name##Attr::get(attr.getContext(), vhloValue.value()) + +Attribute convertAttrToVhlo(Attribute stablehloAttr) { + // Handle StableHLO attributes. + // The logic that handles attributes from other dialects (e.g. builtin + // attributes) lives below. 
+ if (auto attr = stablehloAttr.dyn_cast()) { + return vhlo::ChannelHandleAttr::get(attr.getContext(), attr.getHandle(), + attr.getType()); + } + if (auto attr = + stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(ComparisonDirection); + } + if (auto attr = stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(ComparisonType); + } + if (auto attr = + stablehloAttr.dyn_cast()) { + return vhlo::ConvDimensionNumbersAttr::get( + attr.getContext(), attr.getInputBatchDimension(), + attr.getInputFeatureDimension(), attr.getInputSpatialDimensions(), + attr.getKernelInputFeatureDimension(), + attr.getKernelOutputFeatureDimension(), + attr.getKernelSpatialDimensions(), attr.getOutputBatchDimension(), + attr.getOutputFeatureDimension(), attr.getOutputSpatialDimensions()); + } + if (auto attr = + stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(CustomCallApiVersion); + } + if (auto attr = + stablehloAttr.dyn_cast()) { + return vhlo::DotDimensionNumbersAttr::get( + attr.getContext(), attr.getLhsBatchingDimensions(), + attr.getRhsBatchingDimensions(), attr.getLhsContractingDimensions(), + attr.getRhsContractingDimensions()); + } + if (auto attr = stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(FftType); + } + if (auto attr = + stablehloAttr.dyn_cast()) { + return vhlo::GatherDimensionNumbersAttr::get( + attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(), + attr.getStartIndexMap(), attr.getIndexVectorDim()); + } + if (auto attr = stablehloAttr.dyn_cast()) { + return vhlo::OutputOperandAliasAttr::get( + attr.getContext(), attr.getOutputTupleIndices(), attr.getOperandIndex(), + attr.getOperandTupleIndices()); + } + if (auto attr = stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(Precision); + } + if (auto attr = stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(RngAlgorithm); + } + if (auto attr = stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(RngDistribution); + } + if (auto attr = + stablehloAttr.dyn_cast()) { 
+ return vhlo::ScatterDimensionNumbersAttr::get( + attr.getContext(), attr.getUpdateWindowDims(), + attr.getInsertedWindowDims(), attr.getScatterDimsToOperandDims(), + attr.getIndexVectorDim()); + } + if (auto attr = stablehloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(Transpose); + } + if (stablehloAttr.getDialect().getNamespace() == + stablehlo::StablehloDialect::getDialectNamespace()) { + // All StableHLO attributes must have counterparts in VHLO. + return {}; + } + + // Handle non-StableHLO attributes. + // If an attribute is not defined in StableHLO, then it is unchanged, + // with the exception of ArrayAttr which is converted recursively. + // This will change once we fork necessary upstream types to VHLO. + if (auto stablehloAttrs = stablehloAttr.dyn_cast()) { + SmallVector vhloAttrs; + for (auto stablehloAttr : stablehloAttrs) { + auto vhloAttr = convertAttrToVhlo(stablehloAttr); + if (!vhloAttr) return {}; + vhloAttrs.push_back(vhloAttr); + } + return ArrayAttr::get(stablehloAttrs.getContext(), vhloAttrs); + } + return stablehloAttr; +} + +#undef RETURN_CONVERTED_ENUM_ATTR + +template +class StablehloToVhloOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + StablehloOpTy stablehloOp, typename StablehloOpTy::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + SmallVector vhloTypes; + LLVM_DEBUG(llvm::dbgs() << "Converting types:\n"); + if (failed(this->getTypeConverter()->convertTypes( + stablehloOp->getResultTypes(), vhloTypes))) { + LLVM_DEBUG(llvm::dbgs() << "Failed type conversion\n"); + return failure(); + } + + // These operands have already been converted to VHLO by + // the dialect conversion infrastructure. 
+ ValueRange vhloOperands = adaptor.getOperands(); + + SmallVector vhloAttrs; + for (NamedAttribute stablehloAttr : stablehloOp->getAttrs()) { + auto vhloAttr = convertAttrToVhlo(stablehloAttr.getValue()); + if (!vhloAttr) return failure(); + vhloAttrs.push_back({stablehloAttr.getName(), vhloAttr}); + } + + // Convert the vhlo operation to a StableHLO equivalent. + // This can almost be done in a generic fashion, except for + // vhlo.case that uses a variadic number of regions which means an + // additional argument for the generic builder. + StablehloToVhloOp vhloOp; + if constexpr (std::is_same::value) { + vhloOp = rewriter.replaceOpWithNewOp( + stablehloOp, vhloTypes, vhloOperands, vhloAttrs, + stablehloOp.getBranches().size()); + } else { + vhloOp = rewriter.replaceOpWithNewOp>( + stablehloOp, vhloTypes, vhloOperands, vhloAttrs); + } + + for (auto [stablehloRegion, vhloRegion] : + llvm::zip(stablehloOp->getRegions(), vhloOp->getRegions())) { + rewriter.inlineRegionBefore(stablehloRegion, vhloRegion, + vhloRegion.end()); + } + return success(); + } +}; + +template +void populateStablehloToVhloPatterns(RewritePatternSet* patterns, + TypeConverter* converter, + MLIRContext* context) { + patterns->add...>(*converter, + context); +} + +} // namespace + +////////////////////////// +/// StableHLO --> VHLO /// +////////////////////////// + +struct StablehloLegalizeToVhloPass + : public impl::StablehloLegalizeToVhloPassBase< + StablehloLegalizeToVhloPass> { + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + + vhlo::StablehloToVhloTypeConverter converter; + RewritePatternSet patterns(&getContext()); + stablehlo::populateStablehloToVhloPatterns(&patterns, &converter, + &getContext()); + registerFuncOpsForTypeConversion(target, patterns, converter); + + // StableHLO is a subset of VHLO. 
+ if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + LLVM_DEBUG(llvm::dbgs() << "Failed partial conversion\n"); + return signalPassFailure(); + } + } +}; + +void populateStablehloToVhloPatterns(RewritePatternSet* patterns, + TypeConverter* converter, + MLIRContext* context) { + populateStablehloToVhloPatterns< +#define GET_OP_LIST +#include "stablehlo/dialect/StablehloOps.cpp.inc" + >(patterns, converter, context); +} +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/LegalizeVhloToStablehlo.cpp b/stablehlo/transforms/LegalizeVhloToStablehlo.cpp new file mode 100644 index 0000000000..3bedb8d547 --- /dev/null +++ b/stablehlo/transforms/LegalizeVhloToStablehlo.cpp @@ -0,0 +1,216 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/IR/Attributes.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/VhloOps.h" +#include "stablehlo/transforms/MapStablehloToVhlo.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/TypeConversion.h" + +#define DEBUG_TYPE "compat-passes" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_VHLOLEGALIZETOSTABLEHLOPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { + +////////////////////////// +/// VHLO --> StableHLO /// +////////////////////////// +#define RETURN_CONVERTED_ENUM_ATTR(Name) \ + auto vhloValue = vhlo::stringify##Name(attr.getValue()); \ + auto stablehloValue = stablehlo::symbolize##Name(vhloValue); \ + if (!stablehloValue.has_value()) return {}; \ + return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) + +Attribute convertAttrToStablehlo(Attribute vhloAttr) { + LLVM_DEBUG(llvm::dbgs() << "Converting " << vhloAttr); + if (auto attr = vhloAttr.dyn_cast()) { + return stablehlo::ChannelHandleAttr::get(attr.getContext(), + attr.getHandle(), attr.getType()); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(ComparisonDirection); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(ComparisonType); + } + if (auto attr = vhloAttr.dyn_cast()) { + return stablehlo::ConvDimensionNumbersAttr::get( + attr.getContext(), attr.getInputBatchDimension(), + attr.getInputFeatureDimension(), attr.getInputSpatialDimensions(), + attr.getKernelInputFeatureDimension(), + attr.getKernelOutputFeatureDimension(), + attr.getKernelSpatialDimensions(), attr.getOutputBatchDimension(), + attr.getOutputFeatureDimension(), attr.getOutputSpatialDimensions()); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(CustomCallApiVersion); + } + if (auto attr = vhloAttr.dyn_cast()) { + 
return stablehlo::DotDimensionNumbersAttr::get( + attr.getContext(), attr.getLhsBatchingDimensions(), + attr.getRhsBatchingDimensions(), attr.getLhsContractingDimensions(), + attr.getRhsContractingDimensions()); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(FftType); + } + if (auto attr = vhloAttr.dyn_cast()) { + return stablehlo::GatherDimensionNumbersAttr::get( + attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(), + attr.getStartIndexMap(), attr.getIndexVectorDim()); + } + if (auto attr = vhloAttr.dyn_cast()) { + return stablehlo::OutputOperandAliasAttr::get( + attr.getContext(), attr.getOutputTupleIndices(), attr.getOperandIndex(), + attr.getOperandTupleIndices()); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(Precision); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(RngAlgorithm); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(RngDistribution); + } + if (auto attr = vhloAttr.dyn_cast()) { + return stablehlo::ScatterDimensionNumbersAttr::get( + attr.getContext(), attr.getUpdateWindowDims(), + attr.getInsertedWindowDims(), attr.getScatterDimsToOperandDims(), + attr.getIndexVectorDim()); + } + if (auto attr = vhloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(Transpose); + } + if (vhloAttr.getDialect().getNamespace() == + vhlo::VhloDialect::getDialectNamespace()) { + // All VHLO attributes must have counterparts in StableHLO. + return {}; + } + + // Handle non-VHLO attributes. + // If an attribute is not defined in vhlo, then it is unchanged, + // with the exception of ArrayAttr which is converted recursively. + // This will change once we fork necessary upstream types to VHLO. 
+ if (auto vhloAttrs = vhloAttr.dyn_cast()) { + SmallVector stablehloAttrs; + for (auto vhloAttr : vhloAttrs) { + auto stablehloAttr = convertAttrToStablehlo(vhloAttr); + if (!stablehloAttr) return {}; + stablehloAttrs.push_back(stablehloAttr); + } + return ArrayAttr::get(vhloAttrs.getContext(), stablehloAttrs); + } + return vhloAttr; +} + +#undef RETURN_CONVERTED_ENUM_ATTR + +struct VhloLegalizeToStablehloPass + : public impl::VhloLegalizeToStablehloPassBase< + VhloLegalizeToStablehloPass> { + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + + vhlo::VhloToStablehloTypeConverter converter; + RewritePatternSet patterns(&getContext()); + stablehlo::populateVhloToStablehloPatterns(&patterns, &converter, + &getContext()); + registerFuncOpsForTypeConversion(target, patterns, converter); + + // VHLO should always be convertible to StableHLO if upgraded. + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +template +class VhloToStablehloOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + VhloOpTy vhloOp, typename VhloOpTy::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + SmallVector stablehloTypes; + if (failed(this->getTypeConverter()->convertTypes(vhloOp->getResultTypes(), + stablehloTypes))) + return failure(); + + // These operands have already been converted to StableHLO by + // the dialect conversion infrastructure. + ValueRange stablehloOperands = adaptor.getOperands(); + + SmallVector stablehloAttrs; + for (NamedAttribute vhloAttr : vhloOp->getAttrs()) { + auto stablehloAttr = convertAttrToStablehlo(vhloAttr.getValue()); + if (!stablehloAttr) return failure(); + stablehloAttrs.push_back({vhloAttr.getName(), stablehloAttr}); + } + + // Convert the vhlo operation to a StableHLO equivalent. 
+ // This can almost be done in a generic fashion, except for + // vhlo.case that uses a variadic number of regions which means an + // additional argument for the generic builder. + VhloToStablehloOp stablehloOp; + if constexpr (std::is_same::value) { + stablehloOp = rewriter.replaceOpWithNewOp( + vhloOp, stablehloTypes, stablehloOperands, stablehloAttrs, + vhloOp.getBranches().size()); + } else { + stablehloOp = rewriter.replaceOpWithNewOp>( + vhloOp, stablehloTypes, stablehloOperands, stablehloAttrs); + } + + for (auto [vhloRegion, stablehloRegion] : + llvm::zip(vhloOp->getRegions(), stablehloOp->getRegions())) { + rewriter.inlineRegionBefore(vhloRegion, stablehloRegion, + stablehloRegion.end()); + } + return success(); + } +}; + +template +void populateVhloToStablehloPatterns(RewritePatternSet* patterns, + TypeConverter* converter, + MLIRContext* context) { + patterns + ->add>...>( + *converter, context); +} + +} // namespace + +void populateVhloToStablehloPatterns(RewritePatternSet* patterns, + TypeConverter* converter, + MLIRContext* context) { + populateVhloToStablehloPatterns< +#define GET_OP_LIST +#include "stablehlo/dialect/StablehloOps.cpp.inc" + >(patterns, converter, context); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/MapStablehloToVhlo.h b/stablehlo/transforms/MapStablehloToVhlo.h new file mode 100644 index 0000000000..6f1762c484 --- /dev/null +++ b/stablehlo/transforms/MapStablehloToVhlo.h @@ -0,0 +1,171 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_TRANSFORMS_MAPSTABLEHLOTOVHLO_H +#define STABLEHLO_TRANSFORMS_MAPSTABLEHLOTOVHLO_H + +#include + +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/VhloOps.h" + +namespace mlir { +namespace stablehlo { + +template +struct VhloToStablehloOpImpl { + using Type = std::false_type; +}; +template +using VhloToStablehloOp = typename VhloToStablehloOpImpl::Type; + +template +struct StablehloToVhloOpImpl { + using Type = std::false_type; +}; +template +using StablehloToVhloOp = typename StablehloToVhloOpImpl::Type; + +#define MAP_STABLEHLO_TO_VHLO(OpName, OpVer) \ + template <> \ + struct StablehloToVhloOpImpl { \ + using Type = vhlo::OpName##OpVer; \ + }; \ + template <> \ + struct VhloToStablehloOpImpl { \ + using Type = stablehlo::OpName; \ + }; + +MAP_STABLEHLO_TO_VHLO(AbsOp, V1) +MAP_STABLEHLO_TO_VHLO(AddOp, V1) +MAP_STABLEHLO_TO_VHLO(AfterAllOp, V1) +MAP_STABLEHLO_TO_VHLO(AllGatherOp, V2) +MAP_STABLEHLO_TO_VHLO(AllReduceOp, V1) +MAP_STABLEHLO_TO_VHLO(AllToAllOp, V1) +MAP_STABLEHLO_TO_VHLO(AndOp, V1) +MAP_STABLEHLO_TO_VHLO(Atan2Op, V1) +MAP_STABLEHLO_TO_VHLO(BatchNormGradOp, V1) +MAP_STABLEHLO_TO_VHLO(BatchNormInferenceOp, V1) +MAP_STABLEHLO_TO_VHLO(BatchNormTrainingOp, V1) +MAP_STABLEHLO_TO_VHLO(BitcastConvertOp, V1) +MAP_STABLEHLO_TO_VHLO(BroadcastInDimOp, V1) +MAP_STABLEHLO_TO_VHLO(BroadcastOp, V1) +MAP_STABLEHLO_TO_VHLO(CaseOp, V1) +MAP_STABLEHLO_TO_VHLO(CbrtOp, V1) +MAP_STABLEHLO_TO_VHLO(CeilOp, V1) +MAP_STABLEHLO_TO_VHLO(CholeskyOp, V1) +MAP_STABLEHLO_TO_VHLO(ClampOp, V1) +MAP_STABLEHLO_TO_VHLO(ClzOp, V1) +MAP_STABLEHLO_TO_VHLO(CollectivePermuteOp, V2) +MAP_STABLEHLO_TO_VHLO(CompareOp, V1) +MAP_STABLEHLO_TO_VHLO(ComplexOp, V1) +MAP_STABLEHLO_TO_VHLO(ComputeReshapeShapeOp, V1) +MAP_STABLEHLO_TO_VHLO(ConcatenateOp, V1) 
+MAP_STABLEHLO_TO_VHLO(ConstantOp, V1) +MAP_STABLEHLO_TO_VHLO(ConvertOp, V1) +MAP_STABLEHLO_TO_VHLO(ConvolutionOp, V1) +MAP_STABLEHLO_TO_VHLO(CosineOp, V1) +MAP_STABLEHLO_TO_VHLO(CreateTokenOp, V1) +MAP_STABLEHLO_TO_VHLO(CrossReplicaSumOp, V1) +MAP_STABLEHLO_TO_VHLO(CstrReshapableOp, V1) +MAP_STABLEHLO_TO_VHLO(CustomCallOp, V2) +MAP_STABLEHLO_TO_VHLO(DivOp, V1) +MAP_STABLEHLO_TO_VHLO(DotGeneralOp, V1) +MAP_STABLEHLO_TO_VHLO(DotOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicBroadcastInDimOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicConvOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicGatherOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicIotaOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicPadOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicReshapeOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicSliceOp, V1) +MAP_STABLEHLO_TO_VHLO(DynamicUpdateSliceOp, V1) +MAP_STABLEHLO_TO_VHLO(EinsumOp, V1) +MAP_STABLEHLO_TO_VHLO(Expm1Op, V1) +MAP_STABLEHLO_TO_VHLO(ExpOp, V1) +MAP_STABLEHLO_TO_VHLO(FftOp, V1) +MAP_STABLEHLO_TO_VHLO(FloorOp, V1) +MAP_STABLEHLO_TO_VHLO(GatherOp, V1) +MAP_STABLEHLO_TO_VHLO(GetDimensionSizeOp, V1) +MAP_STABLEHLO_TO_VHLO(GetTupleElementOp, V1) +MAP_STABLEHLO_TO_VHLO(IfOp, V1) +MAP_STABLEHLO_TO_VHLO(ImagOp, V1) +MAP_STABLEHLO_TO_VHLO(InfeedOp, V1) +MAP_STABLEHLO_TO_VHLO(IotaOp, V1) +MAP_STABLEHLO_TO_VHLO(IsFiniteOp, V1) +MAP_STABLEHLO_TO_VHLO(Log1pOp, V1) +MAP_STABLEHLO_TO_VHLO(LogisticOp, V1) +MAP_STABLEHLO_TO_VHLO(LogOp, V1) +MAP_STABLEHLO_TO_VHLO(MapOp, V1) +MAP_STABLEHLO_TO_VHLO(MaxOp, V1) +MAP_STABLEHLO_TO_VHLO(MinOp, V1) +MAP_STABLEHLO_TO_VHLO(MulOp, V1) +MAP_STABLEHLO_TO_VHLO(NegOp, V1) +MAP_STABLEHLO_TO_VHLO(NotOp, V1) +MAP_STABLEHLO_TO_VHLO(OptimizationBarrierOp, V1) +MAP_STABLEHLO_TO_VHLO(OrOp, V1) +MAP_STABLEHLO_TO_VHLO(OutfeedOp, V1) +MAP_STABLEHLO_TO_VHLO(PadOp, V1) +MAP_STABLEHLO_TO_VHLO(PopulationCountOp, V1) +MAP_STABLEHLO_TO_VHLO(PowOp, V1) +MAP_STABLEHLO_TO_VHLO(RealDynamicSliceOp, V1) +MAP_STABLEHLO_TO_VHLO(RealOp, V1) +MAP_STABLEHLO_TO_VHLO(RecvOp, V1) +MAP_STABLEHLO_TO_VHLO(ReduceOp, V1) 
+MAP_STABLEHLO_TO_VHLO(ReducePrecisionOp, V1) +MAP_STABLEHLO_TO_VHLO(ReduceScatterOp, V1) +MAP_STABLEHLO_TO_VHLO(ReduceWindowOp, V1) +MAP_STABLEHLO_TO_VHLO(RemOp, V1) +MAP_STABLEHLO_TO_VHLO(ReplicaIdOp, V1) +MAP_STABLEHLO_TO_VHLO(ReshapeOp, V1) +MAP_STABLEHLO_TO_VHLO(ReturnOp, V1) +MAP_STABLEHLO_TO_VHLO(ReverseOp, V1) +MAP_STABLEHLO_TO_VHLO(RngBitGeneratorOp, V1) +MAP_STABLEHLO_TO_VHLO(RngOp, V1) +MAP_STABLEHLO_TO_VHLO(RoundOp, V1) +MAP_STABLEHLO_TO_VHLO(RoundNearestEvenOp, V1) +MAP_STABLEHLO_TO_VHLO(RsqrtOp, V1) +MAP_STABLEHLO_TO_VHLO(ScatterOp, V1) +MAP_STABLEHLO_TO_VHLO(SelectAndScatterOp, V1) +MAP_STABLEHLO_TO_VHLO(SelectOp, V1) +MAP_STABLEHLO_TO_VHLO(SendOp, V1) +MAP_STABLEHLO_TO_VHLO(SetDimensionSizeOp, V1) +MAP_STABLEHLO_TO_VHLO(ShiftLeftOp, V1) +MAP_STABLEHLO_TO_VHLO(ShiftRightArithmeticOp, V1) +MAP_STABLEHLO_TO_VHLO(ShiftRightLogicalOp, V1) +MAP_STABLEHLO_TO_VHLO(SignOp, V1) +MAP_STABLEHLO_TO_VHLO(SineOp, V1) +MAP_STABLEHLO_TO_VHLO(SliceOp, V1) +MAP_STABLEHLO_TO_VHLO(SortOp, V1) +MAP_STABLEHLO_TO_VHLO(SqrtOp, V1) +MAP_STABLEHLO_TO_VHLO(SubtractOp, V1) +MAP_STABLEHLO_TO_VHLO(TanhOp, V1) +MAP_STABLEHLO_TO_VHLO(TorchIndexSelectOp, V1) +MAP_STABLEHLO_TO_VHLO(TraceOp, V1) +MAP_STABLEHLO_TO_VHLO(TransposeOp, V1) +MAP_STABLEHLO_TO_VHLO(TriangularSolveOp, V1) +MAP_STABLEHLO_TO_VHLO(TupleOp, V1) +MAP_STABLEHLO_TO_VHLO(UnaryEinsumOp, V1) +MAP_STABLEHLO_TO_VHLO(UniformDequantizeOp, V1) +MAP_STABLEHLO_TO_VHLO(UniformQuantizeOp, V1) +MAP_STABLEHLO_TO_VHLO(WhileOp, V1) +MAP_STABLEHLO_TO_VHLO(XorOp, V1) + +#undef MAP_STABLEHLO_TO_VHLO + +} // namespace stablehlo +} // namespace mlir + +#endif // STABLEHLO_TRANSFORMS_MAPSTABLEHLOTOVHLO_H diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h new file mode 100644 index 0000000000..a6a77e53b3 --- /dev/null +++ b/stablehlo/transforms/Passes.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The StableHLO Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_TRANSFORMS_PASSES_H +#define STABLEHLO_TRANSFORMS_PASSES_H + +#include <memory> + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace stablehlo { +#define GEN_PASS_DECL_STABLEHLOLEGALIZETOVHLOPASS +#define GEN_PASS_DECL_VHLOLEGALIZETOSTABLEHLOPASS +#define GEN_PASS_DECL_VHLOUPGRADEPASS +#define GEN_PASS_DECL_VHLODOWNGRADEPASS +#define GEN_PASS_DECL_VHLOTOVERSIONPASS +#define GEN_PASS_REGISTRATION +#include "stablehlo/transforms/Passes.h.inc" + +// Populates StableHLO ops to VHLO ops rewriting patterns. +void populateStablehloToVhloPatterns(RewritePatternSet *patterns, +                                     TypeConverter *converter, +                                     MLIRContext *context); + +// Populates VHLO ops to StableHLO ops rewriting patterns. +void populateVhloToStablehloPatterns(RewritePatternSet *patterns, +                                     TypeConverter *converter, +                                     MLIRContext *context); + +// Populates VHLO downgrade rewriting patterns. +void populateVhloToVersionPatterns(RewritePatternSet *patterns, +                                   TypeConverter *converter, +                                   MLIRContext *contexts); +} // namespace stablehlo +} // namespace mlir + +#endif // STABLEHLO_TRANSFORMS_PASSES_H diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td new file mode 100644 index 0000000000..6e5b1e9ccb --- /dev/null +++ b/stablehlo/transforms/Passes.td @@ -0,0 +1,34 @@ +/* Copyright 2022 The StableHLO Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { + let summary = "Legalize StableHLO to VHLO."; + let dependentDialects = ["mlir::vhlo::VhloDialect"]; +} + +def VhloLegalizeToStablehloPass : Pass<"vhlo-legalize-to-stablehlo", "ModuleOp"> { + let summary = "Legalize VHLO to StableHLO."; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + +def VhloToVersionPass : Pass<"vhlo-to-version"> { + let summary = "Convert between versions of VHLO."; + let options = [ + Option<"targetVersion", "target", "std::string", "", + "The target version. Must be a version of the form #.#.# or 'current'.">, + ]; +} diff --git a/stablehlo/transforms/TypeConversion.cpp b/stablehlo/transforms/TypeConversion.cpp new file mode 100644 index 0000000000..a1a06a9b68 --- /dev/null +++ b/stablehlo/transforms/TypeConversion.cpp @@ -0,0 +1,42 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/transforms/TypeConversion.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" + +namespace mlir { +namespace vhlo { + +void registerFuncOpsForTypeConversion(ConversionTarget& target, + RewritePatternSet& patterns, + TypeConverter& converter) { + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + return converter.isSignatureLegal(op.getCalleeType()); + }); + target.addDynamicallyLegalOp([&](func::ReturnOp op) { + return converter.isLegal(op.getOperandTypes()); + }); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); +} + +} // namespace vhlo +} // namespace mlir diff --git a/stablehlo/transforms/TypeConversion.h b/stablehlo/transforms/TypeConversion.h new file mode 100644 index 0000000000..05883d8901 --- /dev/null +++ b/stablehlo/transforms/TypeConversion.h @@ -0,0 +1,136 @@ +/* Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_TRANSFORMS_TYPECONVERSION_H +#define STABLEHLO_TRANSFORMS_TYPECONVERSION_H + +#include "llvm/Support/Debug.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/VhloOps.h" + +#define DEBUG_TYPE "compat-passes" + +namespace mlir { +namespace vhlo { + +class VersionedTypeConverterBase : public TypeConverter { + public: + VersionedTypeConverterBase() : TypeConverter() { + addConversion([](Type t) -> Type { return t; }); + addConversion([&](TupleType type) -> Type { + SmallVector convertedTypes; + if (failed(convertTypes(type.getTypes(), convertedTypes))) return {}; + return TupleType::get(type.getContext(), convertedTypes); + }); + addConversion([&](RankedTensorType type) -> Type { + auto encoding = type.getEncoding(); + if (!encoding) return type; + if (isSourceDialect(encoding.getDialect())) { + auto convertedEncoding = convertEncoding(encoding); + if (!convertedEncoding) return {}; + return RankedTensorType::get(type.getShape(), type.getElementType(), + convertedEncoding); + } + return type; + }); + }; + + virtual ~VersionedTypeConverterBase() = default; + + // Checks whether the given dialect is the source dialect of the type + // conversion (e.g. StableHLO for StablehloToVhloTypeConverter). 
+ virtual bool isSourceDialect(Dialect& dialect) = 0; + + virtual Attribute convertEncoding(Attribute attr) = 0; +}; + +class StablehloToVhloTypeConverter : public VersionedTypeConverterBase { + public: + StablehloToVhloTypeConverter() : VersionedTypeConverterBase() { + addConversion([](stablehlo::TokenType token) -> Type { + LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n"); + return TokenType::get(token.getContext()); + }); + } + + bool isSourceDialect(Dialect& dialect) final { + return dialect.getNamespace() == + stablehlo::StablehloDialect::getDialectNamespace(); + } + + Attribute convertEncoding(Attribute attr) final { + LLVM_DEBUG(llvm::dbgs() << "Converting encoding.\n"); + LLVM_DEBUG(llvm::dbgs() << attr); + if (auto stablehloAttr = + attr.dyn_cast_or_null()) { + LLVM_DEBUG(llvm::dbgs() << "Matched StableHLO encoding.\n"); + return vhlo::TypeExtensionsAttr::get(stablehloAttr.getContext(), + stablehloAttr.getBounds()); + } + // All encodings should be supported. + return {}; + } +}; + +class VhloToStablehloTypeConverter : public VersionedTypeConverterBase { + public: + VhloToStablehloTypeConverter() : VersionedTypeConverterBase() { + addConversion([](vhlo::TokenType token) -> Type { + LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n"); + return stablehlo::TokenType::get(token.getContext()); + }); + } + + bool isSourceDialect(Dialect& dialect) final { + return dialect.getNamespace() == vhlo::VhloDialect::getDialectNamespace(); + } + + Attribute convertEncoding(Attribute attr) final { + if (auto vhloAttr = attr.dyn_cast_or_null()) { + return stablehlo::TypeExtensionsAttr::get(vhloAttr.getContext(), + vhloAttr.getBounds()); + } + // All encodings should be supported. 
+ return attr; + } +}; + +class VhloToVersionConverter : public VersionedTypeConverterBase { + public: + VhloToVersionConverter() : VersionedTypeConverterBase() { + addConversion([](stablehlo::TokenType token) -> Type { + LLVM_DEBUG(llvm::dbgs() << "Converting TokenType\n"); + return TokenType::get(token.getContext()); + }); + } + + bool isSourceDialect(Dialect& dialect) final { + return dialect.getNamespace() == vhlo::VhloDialect::getDialectNamespace(); + } + + Attribute convertEncoding(Attribute attr) final { return attr; } +}; + +// Complements conversion patterns with boilerplate that makes sure `func.func`, +// `func.call` and `func.return` ops which involve illegal types get converted +// to use legal types. +void registerFuncOpsForTypeConversion(ConversionTarget& target, + RewritePatternSet& patterns, + TypeConverter& converter); +} // namespace vhlo +} // namespace mlir + +#endif // STABLEHLO_TRANSFORMS_TYPECONVERSION_H diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp new file mode 100644 index 0000000000..c65b8acece --- /dev/null +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -0,0 +1,261 @@ +/* Copyright 2022 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/Version.h" +#include "stablehlo/dialect/VhloOps.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/TypeConversion.h" + +#define DEBUG_TYPE "compat-passes" + +namespace mlir { +namespace stablehlo { +#define GEN_PASS_DEF_VHLOTOVERSIONPASS +#include "stablehlo/transforms/Passes.h.inc" +} // namespace stablehlo + +/////////////////////// +/// VHLO To Version /// +/////////////////////// +namespace vhlo { +namespace { + +FailureOr parseTargetVersion(llvm::StringRef versionRef) { + if (versionRef == "current") { + return VhloDialect::getCurrentVersion(); + } + return Version::fromString(versionRef); +} + +FailureOr validateTargetVersion(llvm::StringRef versionRef, + Operation* op) { + auto failOrVersion = parseTargetVersion(versionRef); + if (failed(failOrVersion)) { + if (versionRef.empty()) { + return emitError(op->getLoc()) + << "No target version specified. 
Specify target using: " + "--vhlo-to-version='target=[targetVersion]'\n" + << "Target version must be of the form #.#.# or 'current'."; + } + return emitError(op->getLoc()) + << "Invalid target version argument '" << versionRef << "'\n" + << "Target version must be of the form #.#.# or 'current'."; + } + + Version targetVersion = *failOrVersion; + if (targetVersion < VhloDialect::getMinimumVersion()) { + return emitError(op->getLoc()) << "target version " << targetVersion + << " is less than minimum supported " + << VhloDialect::getMinimumVersion(); + } + if (VhloDialect::getCurrentVersion() < targetVersion) { + return emitError(op->getLoc()) << "target version " << targetVersion + << " is greater than current version " + << VhloDialect::getCurrentVersion(); + } + return targetVersion; +} + +using stablehlo::VhloToVersionPassOptions; +using stablehlo::impl::VhloToVersionPassBase; +struct VhloToVersionPass : public VhloToVersionPassBase { + VhloToVersionPass() : VhloToVersionPassBase() {} + VhloToVersionPass(VhloToVersionPassOptions const& opts) + : VhloToVersionPassBase(opts) {} + + void runOnOperation() override { + ConversionTarget target(getContext()); + + // Validate version number + auto failOrVersion = validateTargetVersion(targetVersion, getOperation()); + if (failed(failOrVersion)) { + return signalPassFailure(); + } + Version targetVersion = *failOrVersion; + + // An op is legal if the target version is in the ops `[min, max]` + // supported version range. 
+ // Example: + // CustomCallV1 0.0.0 -> 0.0.x + // CustomCallV2 0.1.0 -> 0.4.x + // CustomCallV3 0.5.0 -> Curr + // Target Curr (0.5.0): + // V3 legal { Curr in [0.5, Curr] } + // V2 illegal { Curr !in [0.1, 0.4] } + // V1 illegal { Curr !in [0.0, 0.0] } + // Target 0.4.0: + // V3 illegal { 0.4 !in [0.5, Curr] } + // V2 legal { 0.4 in [0.1, 0.4] } + // V1 illegal { 0.4 !in [0.0, 0.0] } + // Target 0.0.0: + // V3 illegal { 0.0 !in [0.5, Curr] } + // V2 illegal { 0.0 !in [0.1, 0.4] } + // V1 legal { 0.0 in [0.0, 0.0] } + target.addDynamicallyLegalDialect( + [&targetVersion](Operation* op) { + if (auto interface = dyn_cast(op)) { + return (interface.getMinVersion() <= targetVersion && + targetVersion <= interface.getMaxVersion()); + } + return false; + }); + + vhlo::VhloToVersionConverter converter; + RewritePatternSet patterns(&getContext()); + stablehlo::populateVhloToVersionPatterns(&patterns, &converter, + &getContext()); + + // Conversions within VHLO may fail if new features or ops are used. + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +//////////////////////////////////////////// +/// Upgrade and Downgrade Infrastructure /// +//////////////////////////////////////////// + +LogicalResult emitDowngradeError(Operation* op, llvm::StringRef message) { + return op->emitError("failed to downgrade ") + << op->getName() << ", " << message; +} + +template +struct VersionConversionPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This method allows subclasses to add or remove attributes if needed. + // Can also fail if an op uses a feature that cannot be represented + // in previous versions of the opset. 
+ virtual LogicalResult prepareOpForConversion(SourceOp op) const = 0; + + LogicalResult matchAndRewrite( + SourceOp op, typename SourceOp::Adaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + if (failed(prepareOpForConversion(op))) { + return failure(); + } + auto newOp = rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands(), op->getAttrs()); + for (auto [oldRegion, newRegion] : + llvm::zip(op->getRegions(), newOp->getRegions())) { + rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.end()); + } + return success(); + } +}; + +///////////////////////////////////////// +/// Upgrade and Downgrade Definitions /// +///////////////////////////////////////// + +// vhlo.custom_call --> vhlo.custom_call_v2 +struct CustomCallOpV1ToV2 + : public VersionConversionPattern { + using VersionConversionPattern::VersionConversionPattern; + LogicalResult prepareOpForConversion(CustomCallOpV1) const final { + return success(); + } +}; + +// vhlo.custom_call_v2 --> vhlo.custom_call +struct CustomCallOpV2ToV1 + : public VersionConversionPattern { + using VersionConversionPattern::VersionConversionPattern; + LogicalResult prepareOpForConversion(CustomCallOpV2 op) const final { + if (!op.getOutputOperandAliases().empty()) { + return emitDowngradeError( + op, "op has a non-empty output_operand_aliases attribute"); + } + if (op->hasAttr("output_operand_aliases")) + op->removeAttr("output_operand_aliases"); + return success(); + } +}; + +// vhlo.collective_permute --> vhlo.collective_permute_v2 +struct CollectivePermuteOpV1ToV2 + : public VersionConversionPattern { + using VersionConversionPattern::VersionConversionPattern; + LogicalResult prepareOpForConversion(CollectivePermuteOpV1) const final { + return success(); + } +}; + +// vhlo.collective_permute_v2 --> vhlo.collective_permute +struct CollectivePermuteOpV2ToV1 + : public VersionConversionPattern { + using VersionConversionPattern::VersionConversionPattern; + LogicalResult 
prepareOpForConversion(CollectivePermuteOpV2 op) const final { + if (op.getChannelHandle().has_value()) { + return emitDowngradeError(op, + "op has a non-empty channel_handle attribute"); + } + return success(); + } +}; + +// vhlo.all_gather--> vhlo.all_gather_v2 +struct AllGatherOpV1ToV2 + : public VersionConversionPattern { + using VersionConversionPattern::VersionConversionPattern; + LogicalResult prepareOpForConversion(AllGatherOpV1) const final { + return success(); + } +}; + +// vhlo.all_gather_v2 --> vhlo.all_gather +struct AllGatherOpV2ToV1 + : public VersionConversionPattern { + using VersionConversionPattern::VersionConversionPattern; + LogicalResult prepareOpForConversion(AllGatherOpV2 op) const final { + if (op.getUseGlobalDeviceIdsAttr()) { + return emitDowngradeError( + op, "op has a non-empty use_global_device_ids attribute"); + } + return success(); + } +}; + +} // namespace +} // namespace vhlo + +namespace stablehlo { +void populateVhloToVersionPatterns(RewritePatternSet* patterns, + TypeConverter* converter, + MLIRContext* context) { + patterns->add(*converter, context); + patterns->add(*converter, context); + patterns->add(*converter, context); + patterns->add(*converter, context); + patterns->add(*converter, context); + patterns->add(*converter, context); +} + +} // namespace stablehlo +} // namespace mlir