-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
StableHLO Versioned Dialect and Compatibility Passes (#278)
Add StableHLO compatibility dialect and passes for reading and writing with forward/backward compatibility guarantees. **Note: This is still a prototype implementation and should not be used in production until RFCs have been approved and types have been forked.** ```bash stablehlo-opt [flags] file.mlir --stablehlo-legalize-to-vhlo --vhlo-to-version='target=[version]' Translate versioned dialect to target version (current, 0.3.0, ...) --vhlo-legalize-to-stablehlo ``` Change description: - Introduce VHLO, the Versioned StableHLO Dialect. + This dialect is a shallow copy of StableHLO's in-memory layout. It does not include verifiers or constraints. + Once an op is added to VHLO it must remain unchanged so that it can be guaranteed that a VHLO op is identical across versions. + The first version of VHLO is `0.3.0`. - Conversion passes for compatibility + StableHLO <--> VHLO legalizations. StableHLO is always able to be legalized to/from the latest version of VHLO. + VHLO-to-version. Target previous versions of VHLO ops for forward compatibility. Upgrade to the latest version of VHLO ops to emit StableHLO. - Testing for legalizations and version conversions. Future work (these items will be made into individual GH issues before submit): - Think more about a scalable way to test this as StableHLO evolves. - Additional feature work on the tool.. any missing flags? Pass pipeline for simplicity? - Improve user experience. - See open [compatibility issues](https://github.com/openxla/stablehlo/labels/Compatibility) Closes #255
- Loading branch information
Showing
34 changed files
with
4,949 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Copyright 2022 The StableHLO Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "stablehlo/dialect/Version.h" | ||
|
||
#include "llvm/ADT/StringRef.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
|
||
namespace mlir { | ||
namespace vhlo { | ||
namespace { | ||
// Helper function for number to string. | ||
// Precondition that numRef is a valid decimal digit. | ||
static int64_t parseNumber(llvm::StringRef numRef) { | ||
int64_t num; | ||
if (numRef.getAsInteger(/*radix=*/10, num)) { | ||
llvm_unreachable("failed to parse version number"); | ||
} | ||
return num; | ||
} | ||
|
||
/// Validate version argument is `#.#.#` (ex: 0.3.0, 1.2.3, 0.123.0) | ||
/// Returns the vector of 3 matches (major, minor, patch) if successful, | ||
/// else returns failure. | ||
static FailureOr<std::array<int64_t, 3>> extractVersionNumbers( | ||
llvm::StringRef versionRef) { | ||
llvm::Regex versionRegex("^([0-9]+)\\.([0-9]+)\\.([0-9]+)$"); | ||
llvm::SmallVector<llvm::StringRef> matches; | ||
if (!versionRegex.match(versionRef, &matches)) { | ||
return failure(); | ||
} | ||
return std::array<int64_t, 3>{parseNumber(matches[1]), | ||
parseNumber(matches[2]), | ||
parseNumber(matches[3])}; | ||
} | ||
} // namespace | ||
|
||
FailureOr<Version> Version::fromString(llvm::StringRef versionRef) { | ||
auto failOrVersionArray = extractVersionNumbers(versionRef); | ||
if (failed(failOrVersionArray)) { | ||
return failure(); | ||
} | ||
|
||
auto versionArr = *failOrVersionArray; | ||
return Version(versionArr[0], versionArr[1], versionArr[2]); | ||
} | ||
|
||
mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version) { | ||
return diag << version.getMajor() << '.' << version.getMinor() << '.' | ||
<< version.getPatch(); | ||
} | ||
|
||
} // namespace vhlo | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* Copyright 2022 The StableHLO Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef STABLEHLO_DIALECT_VERSION_H | ||
#define STABLEHLO_DIALECT_VERSION_H | ||
|
||
#include <algorithm> | ||
#include <cstdint> | ||
#include <string> | ||
|
||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/ADT/StringRef.h" | ||
#include "llvm/Support/Regex.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
namespace mlir { | ||
namespace vhlo { | ||
|
||
class Version { | ||
public: | ||
/// Convenience method to extract major, minor, patch and create a Version | ||
/// from a StringRef of the form `#.#.#`. Returns failure if invalid string. | ||
static FailureOr<Version> fromString(llvm::StringRef versionRef); | ||
|
||
/// Construct Version from major, minor, patch integers. | ||
Version(int64_t major, int64_t minor, int64_t patch) | ||
: majorMinorPatch({major, minor, patch}) {} | ||
|
||
int64_t getMajor() const { return majorMinorPatch[0]; } | ||
int64_t getMinor() const { return majorMinorPatch[1]; } | ||
int64_t getPatch() const { return majorMinorPatch[2]; } | ||
|
||
bool operator<(Version const& other) { | ||
// Uses lexicographical_compare | ||
return majorMinorPatch < other.majorMinorPatch; | ||
} | ||
bool operator==(Version const& other) { | ||
return majorMinorPatch == other.majorMinorPatch; | ||
} | ||
bool operator<=(Version const& other) { | ||
return majorMinorPatch <= other.majorMinorPatch; | ||
} | ||
|
||
private: | ||
std::array<int64_t, 3> majorMinorPatch; | ||
}; | ||
|
||
mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version); | ||
|
||
} // namespace vhlo | ||
} // namespace mlir | ||
|
||
#endif // STABLEHLO_DIALECT_VERSION_H |
Oops, something went wrong.