From 3946bb1775f9482f5b0f10bc355af5d12e7ec99b Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Wed, 29 Jan 2025 18:01:34 +0530 Subject: [PATCH] Bump llvm to llvm/llvm-project@d4159e2a1d1d --- externals/llvm-project | 2 +- externals/stablehlo | 2 +- .../TorchToTosa/TosaLegalizeUtils.h | 9 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 165 +++++++++++------- .../TorchToTosa/TosaLegalizeCommon.cpp | 16 +- .../TorchToTosa/TosaLegalizeUtils.cpp | 56 +++--- lib/Dialect/Torch/Utils/Utils.cpp | 8 +- 7 files changed, 150 insertions(+), 108 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index e2402615a5a7..d4159e2a1d1d 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit e2402615a5a76d46a433dfcc1de10b38a1263c9d +Subproject commit d4159e2a1d1d640077b2e5cde66b0a284049955f diff --git a/externals/stablehlo b/externals/stablehlo index 8cd9444b78cc..48a1e14edc82 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit 8cd9444b78ccec3e42a4b21105a5a547c021e823 +Subproject commit 48a1e14edc8219577fcad53de1924876f855f431 diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 15f29fbc3cab..44f8b0fe4d17 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type); Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, float val); +// Create a 8-bit int constant operator from a int +Value getTosaConstTensorSingleI8(PatternRewriter &rewriter, Operation *op, + int32_t val); + // Create a zero constant tensor of the desired type and shape. std::optional getZerosLikeTensor(PatternRewriter &rewriter, Operation *op, Type type); @@ -127,11 +131,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, RankedTensorType weightTy, RankedTensorType outputTy, TypeAttr &accType); -// Temporary function to get TOSA const shape -// TODO: Remove this function when getTosaConstShape is available in -// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h -Value getTosaConstShape(PatternRewriter &rewriter, Location loc, - llvm::ArrayRef shape); } // namespace tosa } // namespace mlir diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 4ec703d892ad..0dd3db85994a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -622,14 +622,16 @@ Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType, auto boolType = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)); - auto lhsMulRhs = rewriter.create(op->getLoc(), i32Type, lhs, rhs, - /*shift=*/0); + auto lhsMulRhs = rewriter.create( + op->getLoc(), i32Type, lhs, rhs, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto lhsRhsDifferentSign = rewriter.create(op->getLoc(), boolType, zero, lhsMulRhs); - auto truncMulRhs = rewriter.create(op->getLoc(), i32Type, - intDivOp, rhs, /*shift=*/0); + auto truncMulRhs = rewriter.create( + op->getLoc(), i32Type, intDivOp, rhs, + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto truncMulRhsEqualLhs = rewriter.create(op->getLoc(), boolType, truncMulRhs, lhs); @@ -853,7 +855,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( self, zero); auto mulTensor = rewriter.create( op->getLoc(), getTypeConverter()->convertType(op.getType()), self, - alphaTensor, /*shift=*/0); + alphaTensor, tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), cond, self, mulTensor); @@ -2151,8 +2153,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( /*checkForUnity=*/true))) return failure(); - auto multTensor = rewriter.create(op->getLoc(), resultTy, self, - alphaTensor, /*shift=*/0); + auto multTensor = rewriter.create( + op->getLoc(), resultTy, self, alphaTensor, + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); rewriter.replaceOpWithNewOp(op, resultTy, otherTensor, multTensor); @@ -2493,12 +2496,14 @@ Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter, auto op3RsqrtOp2 = rewriter.create( op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult()); - auto op4MulOp1Op3 = rewriter.create(op->getLoc(), outType, - op1SubInputMean.getResult(), - op3RsqrtOp2.getResult(), 0); + auto op4MulOp1Op3 = rewriter.create( + op->getLoc(), outType, op1SubInputMean.getResult(), + op3RsqrtOp2.getResult(), + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto op5MulOp4Scale = rewriter.create( - op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0); + op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); return rewriter .create(op->getLoc(), outType, op5MulOp4Scale.getResult(), @@ -2710,19 +2715,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Compute mean. Value sum = computeSumAndReshape(adaptor.getInput(), inputType, bcastOutType, bcastOutShape); - Value meanVal = rewriter.create(op.getLoc(), bcastOutType, sum, - elemCntRcp, /*shift=*/0); + Value meanVal = rewriter.create( + op.getLoc(), bcastOutType, sum, elemCntRcp, + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); // Compute variance. Value squareSumSub = rewriter.create( op.getLoc(), inputType, adaptor.getInput(), meanVal); - Value squareSum = rewriter.create(op.getLoc(), inputType, - squareSumSub, squareSumSub, 0); + Value squareSum = rewriter.create( + op.getLoc(), inputType, squareSumSub, squareSumSub, + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); Value squareSumReduced = computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape); Value varianceVal = rewriter.create( - op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0); + op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); // Reshape weight and bias. SmallVector weightAndBiasBcastShape; @@ -2978,8 +2986,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); auto logOp = rewriter.create(op.getLoc(), outType, self); - rewriter.replaceOpWithNewOp(op, outType, logOp, rcpOp, - /*shift=*/0); + rewriter.replaceOpWithNewOp( + op, outType, logOp, rcpOp, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); return success(); } @@ -3195,32 +3204,48 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto a1 = tosa::getConstTensor(rewriter, op, 0.278393f, {}, dtype).value(); - auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); + auto a1X = rewriter.create( + loc, outType, a1, absX, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto sum = rewriter.create(loc, outType, a1X, one); auto a2 = tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); - auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); - auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); + auto x2 = rewriter.create( + loc, outType, absX, absX, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); + auto a2X = rewriter.create( + loc, outType, a2, x2, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); sum = rewriter.create(loc, outType, sum, a2X); auto a3 = tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); - auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); - auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); + auto x3 = rewriter.create( + loc, outType, x2, absX, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); + auto a3X = rewriter.create( + loc, outType, a3, x3, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); sum = rewriter.create(loc, outType, sum, a3X); auto a4 = tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); - auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); - auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); + auto x4 = rewriter.create( + loc, outType, x3, absX, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); + auto a4X = rewriter.create( + loc, outType, a4, x4, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); sum = rewriter.create(loc, outType, sum, a4X); auto rcprl = rewriter.create(loc, outType, sum); - auto rcprl2 = - rewriter.create(loc, outType, rcprl, rcprl, /*shift=*/0); - auto rcprl4 = - rewriter.create(loc, outType, rcprl2, rcprl2, /*shift=*/0); + auto rcprl2 = rewriter.create( + loc, outType, rcprl, rcprl, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); + auto rcprl4 = rewriter.create( + loc, outType, rcprl2, rcprl2, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto erf = rewriter.create(loc, outType, one, rcprl4); // Deal with negative x. @@ -3248,15 +3273,17 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value rsqrt2 = tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); - Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, - /*shift=*/0); + Value erfArg = rewriter.create( + loc, outType, xMinusMean, rsqrt2, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erfPlus1 = rewriter.create(loc, outType, one, erf); Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); - Value normalCdf = rewriter.create(loc, outType, oneHalf, - erfPlus1, /*shift=*/0); + Value normalCdf = rewriter.create( + loc, outType, oneHalf, erfPlus1, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); return normalCdf; } @@ -3295,8 +3322,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); - rewriter.replaceOpWithNewOp(op, resultType, self, cdf, - /*shift=*/0); + rewriter.replaceOpWithNewOp( + op, resultType, self, cdf, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); } else if (approximate.compare("tanh") == 0) { // "tanh" approximate // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) @@ -3337,8 +3365,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); // 0.5 * x - auto halfInput = rewriter.create(op->getLoc(), resultType, - half, self, /*shift=*/0); + auto halfInput = rewriter.create( + op->getLoc(), resultType, half, self, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); // sqrt(2/pi) auto sqrtTwoOverPi = @@ -3349,9 +3378,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), resultType, self, three); // 0.044715 * x^3 - auto inputPowThreeMul = - rewriter.create(op->getLoc(), resultType, magicNumber, - inputPowThree.getResult(), /*shift=*/0); + auto inputPowThreeMul = rewriter.create( + op->getLoc(), resultType, magicNumber, inputPowThree.getResult(), + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); // x + 0.044715 * x^3 auto inputPowThreeMulAdd = rewriter.create( @@ -3360,7 +3389,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // sqrt(2/pi) * (x + 0.044715 * x^3) auto sqrtTwoOverPiMul = rewriter.create( op->getLoc(), resultType, sqrtTwoOverPi.getResult(), - inputPowThreeMulAdd.getResult(), /*shift=*/0); + inputPowThreeMulAdd.getResult(), + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); // tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) auto tanh = rewriter.create(op->getLoc(), resultType, @@ -3372,7 +3402,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, resultType, halfInput.getResult(), tanhAdd.getResult(), - /*shift=*/0); + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); } else { return rewriter.notifyMatchFailure(op, "Unsupported approximation algorithm"); @@ -3419,22 +3449,26 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value negOneHalf = tosa::getConstTensor(rewriter, op, -0.5f, {}, selfElemTy).value(); Value inputSquared = rewriter.create( - loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); + loc, selfType, adaptor.getSelf(), adaptor.getSelf(), + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); Value negHalfInputSquared = rewriter.create( - loc, selfType, inputSquared, negOneHalf, /*shift=*/0); + loc, selfType, inputSquared, negOneHalf, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); Value dinput = rewriter.create(loc, selfType, negHalfInputSquared); Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); Value dinputInput = rewriter.create( - loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0); + loc, selfType, dinput, adaptor.getSelf(), + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); Value dinputInputAlpha = rewriter.create( - loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0); + loc, selfType, dinputInput, kAlphaHalf, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); Value cdfExt = rewriter.create(loc, selfType, dinputInputAlpha, cdf); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getGradOutput(), cdfExt, - /*shift=*/0); + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); return success(); } @@ -4828,8 +4862,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), otherType, adaptor.getOther()); auto rtolConstOp = tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); - auto mulOp = rewriter.create(op->getLoc(), otherType, - rtolConstOp, lhsAbsOp, /*shift=*/0); + auto mulOp = rewriter.create( + op->getLoc(), otherType, rtolConstOp, lhsAbsOp, + tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto atolConstOp = tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); auto addOp = @@ -5354,7 +5389,8 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { auto otherTensorReciprocal = rewriter.create( op.getLoc(), otherTensor.getType(), otherTensor); divTensor = rewriter.create( - op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0); + op.getLoc(), outType, self, otherTensorReciprocal, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); divTensor = rewriter.create(op.getLoc(), outType, divTensor); } else { @@ -5378,9 +5414,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { } } - auto mulTensor = rewriter.create(op.getLoc(), outType, - otherTensor, divTensor, - /*shift=*/0); + auto mulTensor = rewriter.create( + op.getLoc(), outType, otherTensor, divTensor, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); rewriter.replaceOpWithNewOp(op, outType, self, mulTensor); return success(); @@ -6572,8 +6608,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm_unreachable("Invalid integer width"); }); - rewriter.replaceOpWithNewOp(op, resultType, self, trilMask, - /*shift=*/0); + rewriter.replaceOpWithNewOp( + op, resultType, self, trilMask, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); return success(); } @@ -6663,14 +6700,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto ceilInput = rewriter.create(op->getLoc(), resultTy, self); auto floorInputDivByTwo = rewriter.create( - op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); + op->getLoc(), resultTy, floorInput.getResult(), oneHalf, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto floorDivResult = rewriter.create( op->getLoc(), resultTy, floorInputDivByTwo.getResult()); // (floor(input) // 2) * 2 auto evenComparison = rewriter.create( - op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0); + op->getLoc(), resultTy, floorDivResult.getResult(), two, + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0 auto floorInputEven = rewriter.create( @@ -6849,7 +6888,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value diagonalTensor = rewriter.create( op->getLoc(), transposedInputType, selfTransposed, diagonalMask, - /*shift=*/0); + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto resultShape = makeShapeTorchCompatible(resultType.getShape()); auto targetReduceDim = resultShape[resultType.getRank() - 1]; @@ -8127,9 +8166,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneMinusZiReciprocal = rewriter.create( op->getLoc(), resultType, oneMinusZi.getResult()); - auto mulOp = rewriter.create(op->getLoc(), resultType, zi, - oneMinusZiReciprocal.getResult(), - /*shift=*/0); + auto mulOp = rewriter.create( + op->getLoc(), resultType, zi, oneMinusZiReciprocal.getResult(), + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); auto result = rewriter.create(op->getLoc(), resultType, mulOp.getResult()); @@ -8220,7 +8259,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto result = rewriter.create( op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(), - /*shift=*/0); + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); rewriter.replaceOp(op, {result.getResult()}); @@ -8301,7 +8340,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto result = rewriter.create( op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(), - /*shift=*/0); + /*shift=*/tosa::getTosaConstTensorSingleI8(rewriter, op, 0)); rewriter.replaceOp(op, {result.getResult()}); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 9dedf457096a..c94854eefbf6 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -119,8 +119,9 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op, int32_t shift) { lhs = promoteType(rewriter, lhs, outType); rhs = promoteType(rewriter, rhs, outType); - return tosa::CreateOpAndInfer(rewriter, op->getLoc(), outType, - lhs, rhs, shift); + return tosa::CreateOpAndInfer( + rewriter, op->getLoc(), outType, lhs, rhs, + getTosaConstTensorSingleI8(rewriter, op, shift)); } template <> @@ -384,7 +385,8 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), + getTosaConstTensorSingleI8(rewriter, op, 0)); // Sum up the products of the coefficients and coordinates // %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) -> @@ -650,7 +652,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), + getTosaConstTensorSingleI8(rewriter, op, 0)); // Sum up the products of the coefficients and coordinates // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]] @@ -973,8 +976,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, if (!input_is_qtype) { Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale); - return CreateOpAndInfer(rewriter, op->getLoc(), output_type, - val.value(), div_const, 0) + return CreateOpAndInfer( + rewriter, op->getLoc(), output_type, val.value(), div_const, + getTosaConstTensorSingleI8(rewriter, op, 0)) .getResult(); } diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 1ed360ddae61..023d6b08095a 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -150,6 +150,18 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, return const_op.getResult(); } +// Create a 8-bit int constant operator from a int +Value getTosaConstTensorSingleI8(PatternRewriter &rewriter, Operation *op, + int32_t val) { + auto shiftElementType = IntegerType::get(rewriter.getContext(), 8); + auto shiftType = RankedTensorType::get({1}, shiftElementType); + auto shiftZeroAttr = DenseElementsAttr::get( + shiftType, rewriter.getIntegerAttr(shiftElementType, val)); + Value constVal = + rewriter.create(op->getLoc(), shiftType, shiftZeroAttr); + return constVal; +} + // Create a zero constant tensor of the desired type and shape. std::optional getZerosLikeTensor(PatternRewriter &rewriter, Operation *op, Type type) { @@ -301,31 +313,31 @@ std::optional getConstTensor(PatternRewriter &rewriter, (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isBF16()) || (src.isF32() && dest.isF16()) || - (src.isF32() && dest.isFloat8E4M3()) || - (src.isF32() && dest.isFloat8E5M2()) || + (src.isF32() && isa(dest)) || + (src.isF32() && isa(dest)) || // f16 -> * (src.isF16() && dest.isInteger(32)) || (src.isF16() && dest.isInteger(16)) || (src.isF16() && dest.isInteger(8)) || (src.isF16() && dest.isBF16()) || (src.isF16() && dest.isF32()) || - (src.isF16() && dest.isFloat8E4M3()) || - (src.isF16() && dest.isFloat8E5M2()) || + (src.isF16() && isa(dest)) || + (src.isF16() && isa(dest)) || // bf16 -> * (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isInteger(16)) || (src.isBF16() && dest.isInteger(8)) || (src.isBF16() && dest.isF32()) || - (src.isBF16() && dest.isFloat8E4M3()) || - (src.isBF16() && dest.isFloat8E5M2()) || + (src.isBF16() && isa(dest)) || + (src.isBF16() && isa(dest)) || // fp8e4m3 -> * - (src.isFloat8E4M3() && dest.isBF16()) || - (src.isFloat8E4M3() && dest.isF32()) || - (src.isFloat8E4M3() && dest.isF16()) || + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16()) || // fp8e5m2 -> * - (src.isFloat8E5M2() && dest.isBF16()) || - (src.isFloat8E5M2() && dest.isF32()) || - (src.isFloat8E5M2() && dest.isF16())) { + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16())) { return success(); } // clang-format on @@ -488,10 +500,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && outputElemTy.isInteger(48)) { accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); - } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() && - outputElemTy.isF16()) || - (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() && - outputElemTy.isF16())) { + } else if ((isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16()) || + (isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16())) { accType = mlir::TypeAttr::get(rewriter.getF16Type()); } else { accType = mlir::TypeAttr::get(outputElemTy); @@ -500,17 +512,5 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, return success(); } -// Temporary function to get TOSA const shape -// TODO: Remove this function when getTosaConstShape is available in -// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h -Value getTosaConstShape(PatternRewriter &rewriter, Location loc, - llvm::ArrayRef shape) { - auto attr = rewriter.getIndexTensorAttr(shape); - auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size()); - mlir::Operation *mlir_op = - rewriter.create(loc, type, attr); - return mlir_op->getResult(0); -} - } // namespace tosa } // namespace mlir diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index c0984efffd9c..7f80e84044df 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { return rewriter.getF32Type(); if (isa(inputType)) return rewriter.getF64Type(); - if (inputType.isFloat8E5M2()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E4M3FN()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E5M2FNUZ()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E4M3FNUZ()) + if (isa(inputType)) return rewriter.getF32Type(); if (inputType.isInteger(8)) // this is an intentional deviation from CUDA (which accumulates i8 to i64)