Skip to content

Commit

Permalink
StableHLO Versioned Dialect and Compatibility Passes (#278)
Browse files Browse the repository at this point in the history
Add StableHLO compatibility dialect and passes for reading and writing with forward/backward compatibility guarantees.

**Note: This is still a prototype implementation and should not be used in production until RFCs have been approved and types have been forked.**

```bash
stablehlo-opt [flags] file.mlir
    --stablehlo-legalize-to-vhlo 
    --vhlo-to-version='target=[version]'    Translate versioned dialect to target version (current, 0.3.0, ...)
    --vhlo-legalize-to-stablehlo 
```

Change description:
- Introduce VHLO, the Versioned StableHLO Dialect.
  + This dialect is a shallow copy of StableHLO's in-memory layout. It does not include verifiers or constraints.
  + Once an op is added to VHLO it must remain unchanged so that it can be guaranteed that a VHLO op is identical across versions.
  + The first version of VHLO is `0.3.0`.
- Conversion passes for compatibility
  + StableHLO <--> VHLO legalizations. StableHLO is always able to be legalized to/from the latest version of VHLO.
  + VHLO-to-version. Target previous versions of VHLO ops for forward compatibility. Upgrade to the latest version of VHLO ops to emit StableHLO.
- Testing for legalizations and version conversions.

Future work (these items will be made into individual GH issues before submit):
- Think more about a scalable way to test this as StableHLO evolves.
- Additional feature work on the tool.. any missing flags? Pass pipeline for simplicity?
- Improve user experience.
- See open [compatibility issues](https://github.com/openxla/stablehlo/labels/Compatibility) 

Closes #255
  • Loading branch information
GleasonK authored Dec 3, 2022
1 parent 8a2ce22 commit f2440c0
Show file tree
Hide file tree
Showing 34 changed files with 4,949 additions and 15 deletions.
1 change: 1 addition & 0 deletions stablehlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ add_subdirectory(integrations)
add_subdirectory(reference)
add_subdirectory(tests)
add_subdirectory(tools)
add_subdirectory(transforms)
17 changes: 8 additions & 9 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
#ifndef STABLEHLO_DIALECT_BASE
#define STABLEHLO_DIALECT_BASE

include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
Expand All @@ -44,7 +43,7 @@ def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
//===----------------------------------------------------------------------===//

// TODO(b/230381284): Upstream width-specific uniform quantized element types.
class UniformQuantizedSignedInt<int width>
class StableHLO_UniformQuantizedSignedInt<int width>
: Type<Or<[
And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedType>()">,
CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
Expand All @@ -61,7 +60,7 @@ class UniformQuantizedSignedInt<int width>
int bitwidth = width;
}

class UniformQuantizedUnsignedInt<int width>
class StableHLO_UniformQuantizedUnsignedInt<int width>
: Type<Or<[
And<[CPred<"$_self.isa<mlir::quant::UniformQuantizedType>()">,
CPred<"$_self.cast<mlir::quant::UniformQuantizedType>()" #
Expand All @@ -78,20 +77,20 @@ class UniformQuantizedUnsignedInt<int width>
int bitwidth = width;
}

class UniformQuantizedSignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, UniformQuantizedSignedInt<w>),
class StableHLO_UniformQuantizedSignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, StableHLO_UniformQuantizedSignedInt<w>),
!interleave(widths, "/") # "-bit uniform quantized signed " #
"integer">;

class UniformQuantizedUnsignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, UniformQuantizedUnsignedInt<w>),
class StableHLO_UniformQuantizedUnsignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, StableHLO_UniformQuantizedUnsignedInt<w>),
!interleave(widths, "/") # "-bit uniform quantized unsigned " #
"integer">;

// Integer-based uniform quantized types. The definitions can be used to specify
// operand's tensor types.
def HLO_QuantizedSignedInt : UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedUnsignedInt : UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedSignedInt : StableHLO_UniformQuantizedSignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedUnsignedInt : StableHLO_UniformQuantizedUnsignedIntOfWidths<[4, 8, 16, 32]>;
def HLO_QuantizedInt :
AnyTypeOf<[HLO_QuantizedSignedInt, HLO_QuantizedUnsignedInt]>;

Expand Down
33 changes: 31 additions & 2 deletions stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@ add_mlir_dialect_library(StablehloRegister
DEPENDS
ChloOpsIncGen
StablehloOpsIncGen
VhloOpsIncGen

LINK_LIBS PUBLIC
ChloOps
StablehloOps
VhloOps
)

add_mlir_dialect_library(StablehloAssemblyFormat
PARTIAL_SOURCES_INTENDED
AssemblyFormat.cpp
AssemblyFormat.cpp

LINK_LIBS PUBLIC
StablehloBase
Expand All @@ -101,7 +103,7 @@ add_mlir_dialect_library(StablehloAssemblyFormat

add_mlir_dialect_library(StablehloTypeInference
PARTIAL_SOURCES_INTENDED
TypeInference.cpp
TypeInference.cpp

LINK_LIBS PUBLIC
MLIRInferTypeOpInterface
Expand Down Expand Up @@ -147,3 +149,30 @@ target_include_directories(StablehloOps INTERFACE
$<BUILD_INTERFACE:${STABLEHLO_SOURCE_DIR}>
$<BUILD_INTERFACE:${STABLEHLO_BINARY_DIR}>
)

set(LLVM_TARGET_DEFINITIONS VhloOps.td)
mlir_tablegen(VhloOps.h.inc -gen-op-decls)
mlir_tablegen(VhloOps.cpp.inc -gen-op-defs)
mlir_tablegen(VhloOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(VhloOpInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(VhloEnums.h.inc -gen-enum-decls)
mlir_tablegen(VhloEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(VhloOpsIncGen)
add_dependencies(mlir-headers VhloOpsIncGen)

add_mlir_dialect_library(VhloOps
PARTIAL_SOURCES_INTENDED
VhloOps.cpp
Version.cpp

DEPENDS
VhloOpsIncGen

LINK_LIBS PUBLIC
StablehloAssemblyFormat
MLIRIR
MLIRSupport
MLIRQuantDialect
)
4 changes: 3 additions & 1 deletion stablehlo/dialect/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/VhloOps.h"

namespace mlir {
namespace stablehlo {
Expand All @@ -27,7 +28,8 @@ void registerAllDialects(mlir::DialectRegistry &registry) {
// clang-format off
registry.insert<mlir::sparse_tensor::SparseTensorDialect>();
registry.insert<mlir::chlo::ChloDialect,
mlir::stablehlo::StablehloDialect>();
mlir::stablehlo::StablehloDialect,
mlir::vhlo::VhloDialect>();
// clang-format on
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def StableHLO_DotDimensionNumbers : AttrDef<StableHLO_Dialect, "DotDimensionNumb
let hasCustomAssemblyFormat = 1;
}

def OutputOperandAlias : AttrDef<StableHLO_Dialect, "OutputOperandAlias"> {
def StableHLO_OutputOperandAlias : AttrDef<StableHLO_Dialect, "OutputOperandAlias"> {
let cppNamespace = "::mlir::stablehlo";
let mnemonic = "output_operand_alias";
let summary =
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2101,7 +2101,7 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts,
DefaultValuedOptionalAttr<
TypedArrayAttrBase<
OutputOperandAlias,
StableHLO_OutputOperandAlias,
"Aliasing attribute for outputs and operands of CustomCall">,
"{}">:$output_operand_aliases
);
Expand Down
68 changes: 68 additions & 0 deletions stablehlo/dialect/Version.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2022 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "stablehlo/dialect/Version.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Diagnostics.h"

namespace mlir {
namespace vhlo {
namespace {
// Helper function for number to string.
// Precondition that numRef is a valid decimal digit.
static int64_t parseNumber(llvm::StringRef numRef) {
int64_t num;
if (numRef.getAsInteger(/*radix=*/10, num)) {
llvm_unreachable("failed to parse version number");
}
return num;
}

/// Validate version argument is `#.#.#` (ex: 0.3.0, 1.2.3, 0.123.0)
/// Returns the vector of 3 matches (major, minor, patch) if successful,
/// else returns failure.
static FailureOr<std::array<int64_t, 3>> extractVersionNumbers(
llvm::StringRef versionRef) {
llvm::Regex versionRegex("^([0-9]+)\\.([0-9]+)\\.([0-9]+)$");
llvm::SmallVector<llvm::StringRef> matches;
if (!versionRegex.match(versionRef, &matches)) {
return failure();
}
return std::array<int64_t, 3>{parseNumber(matches[1]),
parseNumber(matches[2]),
parseNumber(matches[3])};
}
} // namespace

FailureOr<Version> Version::fromString(llvm::StringRef versionRef) {
auto failOrVersionArray = extractVersionNumbers(versionRef);
if (failed(failOrVersionArray)) {
return failure();
}

auto versionArr = *failOrVersionArray;
return Version(versionArr[0], versionArr[1], versionArr[2]);
}

mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version) {
return diag << version.getMajor() << '.' << version.getMinor() << '.'
<< version.getPatch();
}

} // namespace vhlo
} // namespace mlir
66 changes: 66 additions & 0 deletions stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2022 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_DIALECT_VERSION_H
#define STABLEHLO_DIALECT_VERSION_H

#include <algorithm>
#include <cstdint>
#include <string>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Regex.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir {
namespace vhlo {

class Version {
public:
/// Convenience method to extract major, minor, patch and create a Version
/// from a StringRef of the form `#.#.#`. Returns failure if invalid string.
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Construct Version from major, minor, patch integers.
Version(int64_t major, int64_t minor, int64_t patch)
: majorMinorPatch({major, minor, patch}) {}

int64_t getMajor() const { return majorMinorPatch[0]; }
int64_t getMinor() const { return majorMinorPatch[1]; }
int64_t getPatch() const { return majorMinorPatch[2]; }

bool operator<(Version const& other) {
// Uses lexicographical_compare
return majorMinorPatch < other.majorMinorPatch;
}
bool operator==(Version const& other) {
return majorMinorPatch == other.majorMinorPatch;
}
bool operator<=(Version const& other) {
return majorMinorPatch <= other.majorMinorPatch;
}

private:
std::array<int64_t, 3> majorMinorPatch;
};

mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version);

} // namespace vhlo
} // namespace mlir

#endif // STABLEHLO_DIALECT_VERSION_H
Loading

0 comments on commit f2440c0

Please sign in to comment.