Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 committed Jan 29, 2025
1 parent 4b9b972 commit 3946bb1
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 108 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 4692 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 52 files
+1 −1 BUILD.bazel
+2 −2 WORKSPACE.bazel
+1 −1 build_tools/llvm_version.txt
+17 −1 docs/awesome.md
+3 −16 stablehlo/conversions/linalg/transforms/Rewriters.h
+39 −38 stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+23 −0 stablehlo/dialect/AssemblyFormat.cpp
+59 −0 stablehlo/dialect/AssemblyFormat.h
+17 −0 stablehlo/dialect/Base.cpp
+3 −0 stablehlo/dialect/Base.h
+17 −0 stablehlo/dialect/Base.td
+1 −1 stablehlo/dialect/CMakeLists.txt
+15 −0 stablehlo/dialect/StablehloAttrs.td
+79 −7 stablehlo/dialect/StablehloBytecode.cpp
+23 −0 stablehlo/dialect/StablehloEnums.td
+38 −0 stablehlo/dialect/StablehloOps.cpp
+19 −2 stablehlo/dialect/StablehloOps.td
+29 −4 stablehlo/dialect/TypeInference.cpp
+9 −0 stablehlo/dialect/TypeInference.h
+3 −3 stablehlo/dialect/Version.cpp
+1 −1 stablehlo/dialect/Version.h
+24 −11 stablehlo/dialect/VhloAttrs.td
+74 −1 stablehlo/dialect/VhloBytecode.cpp
+1 −0 stablehlo/dialect/VhloDialect.td
+33 −1 stablehlo/dialect/VhloEnums.td
+9 −8 stablehlo/dialect/VhloOps.cpp
+8 −1 stablehlo/dialect/VhloOps.td
+67 −0 stablehlo/integrations/c/StablehloAttributes.cpp
+37 −0 stablehlo/integrations/c/StablehloAttributes.h
+44 −0 stablehlo/integrations/python/StablehloModule.cpp
+21 −0 stablehlo/integrations/python/tests/stablehlo.py
+6 −7 stablehlo/reference/Types.cpp
+40 −0 stablehlo/tests/ops_stablehlo.mlir
+63 −0 stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
+5 −0 stablehlo/tests/ops_stablehlo_roundtrip.mlir
+13 −0 stablehlo/tests/print_stablehlo.mlir
+11 −0 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+2,966 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc
+31 −1 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+26 −0 stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
+24 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir
+22 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir
+1 −1 stablehlo/transforms/MapStablehloToVhlo.h
+3 −3 stablehlo/transforms/PassUtils.h
+5 −0 stablehlo/transforms/Passes.h
+20 −2 stablehlo/transforms/StablehloAggressiveSimplification.cpp
+6 −3 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td
+24 −0 stablehlo/transforms/StablehloLegalizeToVhlo.cpp
+24 −0 stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+53 −0 stablehlo/transforms/VhloToVersion.cpp
+16 −0 stablehlo/transforms/VhloToVersionPatterns.td
9 changes: 4 additions & 5 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);

// Create an 8-bit int constant operator from an int
Value getTosaConstTensorSingleI8(PatternRewriter &rewriter, Operation *op,
int32_t val);

// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type);
Expand Down Expand Up @@ -127,11 +131,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
RankedTensorType weightTy,
RankedTensorType outputTy, TypeAttr &accType);

// Temporary function to get TOSA const shape
// TODO: Remove this function when getTosaConstShape is available in
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);
} // namespace tosa
} // namespace mlir

Expand Down
165 changes: 102 additions & 63 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
int32_t shift) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
lhs, rhs, shift);
return tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(), outType, lhs, rhs,
getTosaConstTensorSingleI8(rewriter, op, shift));
}

template <>
Expand Down Expand Up @@ -384,7 +385,8 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(),
getTosaConstTensorSingleI8(rewriter, op, 0));

// Sum up the products of the coefficients and coordinates
// %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
Expand Down Expand Up @@ -650,7 +652,8 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(),
getTosaConstTensorSingleI8(rewriter, op, 0));

// Sum up the products of the coefficients and coordinates
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
Expand Down Expand Up @@ -973,8 +976,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,

if (!input_is_qtype) {
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
val.value(), div_const, 0)
return CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(), output_type, val.value(), div_const,
getTosaConstTensorSingleI8(rewriter, op, 0))
.getResult();
}

Expand Down
56 changes: 28 additions & 28 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,18 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
return const_op.getResult();
}

// Create a rank-1, single-element i8 constant tensor holding `val`
// (the value is narrowed to 8 bits). Used e.g. for the TOSA mul shift operand.
Value getTosaConstTensorSingleI8(PatternRewriter &rewriter, Operation *op,
                                 int32_t val) {
  auto elemTy = IntegerType::get(rewriter.getContext(), 8);
  auto tensorTy = RankedTensorType::get({1}, elemTy);
  auto valueAttr =
      DenseElementsAttr::get(tensorTy, rewriter.getIntegerAttr(elemTy, val));
  return rewriter.create<tosa::ConstOp>(op->getLoc(), tensorTy, valueAttr)
      .getResult();
}

// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type) {
Expand Down Expand Up @@ -301,31 +313,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && dest.isFloat8E4M3()) ||
(src.isF32() && dest.isFloat8E5M2()) ||
(src.isF32() && isa<Float8E4M3Type>(dest)) ||
(src.isF32() && isa<Float8E5M2Type>(dest)) ||
// f16 -> *
(src.isF16() && dest.isInteger(32)) ||
(src.isF16() && dest.isInteger(16)) ||
(src.isF16() && dest.isInteger(8)) ||
(src.isF16() && dest.isBF16()) ||
(src.isF16() && dest.isF32()) ||
(src.isF16() && dest.isFloat8E4M3()) ||
(src.isF16() && dest.isFloat8E5M2()) ||
(src.isF16() && isa<Float8E4M3Type>(dest)) ||
(src.isF16() && isa<Float8E5M2Type>(dest)) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isF32()) ||
(src.isBF16() && dest.isFloat8E4M3()) ||
(src.isBF16() && dest.isFloat8E5M2()) ||
(src.isBF16() && isa<Float8E4M3Type>(dest)) ||
(src.isBF16() && isa<Float8E5M2Type>(dest)) ||
// fp8e4m3 -> *
(src.isFloat8E4M3() && dest.isBF16()) ||
(src.isFloat8E4M3() && dest.isF32()) ||
(src.isFloat8E4M3() && dest.isF16()) ||
(isa<Float8E4M3Type>(src) && dest.isBF16()) ||
(isa<Float8E4M3Type>(src) && dest.isF32()) ||
(isa<Float8E4M3Type>(src) && dest.isF16()) ||
// fp8e5m2 -> *
(src.isFloat8E5M2() && dest.isBF16()) ||
(src.isFloat8E5M2() && dest.isF32()) ||
(src.isFloat8E5M2() && dest.isF16())) {
(isa<Float8E5M2Type>(src) && dest.isBF16()) ||
(isa<Float8E5M2Type>(src) && dest.isF32()) ||
(isa<Float8E5M2Type>(src) && dest.isF16())) {
return success();
}
// clang-format on
Expand Down Expand Up @@ -488,10 +500,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
outputElemTy.isInteger(48)) {
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
} else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
outputElemTy.isF16()) ||
(inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
outputElemTy.isF16())) {
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
(isa<Float8E5M2Type>(inputElemTy) &&
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
accType = mlir::TypeAttr::get(rewriter.getF16Type());
} else {
accType = mlir::TypeAttr::get(outputElemTy);
Expand All @@ -500,17 +512,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
return success();
}

// Temporary helper that materializes `shape` as a tosa.const_shape op.
// TODO: Remove this function when getTosaConstShape is available in
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
                        llvm::ArrayRef<int64_t> shape) {
  // The result type is a !tosa.shape whose rank matches the number of dims;
  // the payload is carried as an index-tensor attribute.
  auto shapeTy =
      mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
  auto shapeAttr = rewriter.getIndexTensorAttr(shape);
  return rewriter.create<tosa::ConstShapeOp>(loc, shapeTy, shapeAttr)
      ->getResult(0);
}

} // namespace tosa
} // namespace mlir
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
return rewriter.getF32Type();
if (isa<Float64Type>(inputType))
return rewriter.getF64Type();
if (inputType.isFloat8E5M2())
if (isa<Float8E5M2Type>(inputType))
return rewriter.getF32Type();
if (inputType.isFloat8E4M3FN())
if (isa<Float8E4M3FNType>(inputType))
return rewriter.getF32Type();
if (inputType.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(inputType))
return rewriter.getF32Type();
if (inputType.isFloat8E4M3FNUZ())
if (isa<Float8E4M3FNUZType>(inputType))
return rewriter.getF32Type();
if (inputType.isInteger(8))
// this is an intentional deviation from CUDA (which accumulates i8 to i64)
Expand Down

0 comments on commit 3946bb1

Please sign in to comment.