diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 310b4b56e851..3e4e6089389a 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project namespace mlir { @@ -394,7 +394,8 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value(); - if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue).failed()) + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue) + .failed()) return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); @@ -493,10 +494,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter, } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && outputElemTy.isInteger(48)) { accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); - } else if ((isa(inputElemTy) && isa(weightElemTy) && - outputElemTy.isF16()) || - (isa(inputElemTy) && isa(weightElemTy) && - outputElemTy.isF16())) { + } else if ((isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16()) || + (isa(inputElemTy) && + isa(weightElemTy) && outputElemTy.isF16())) { accType = mlir::TypeAttr::get(rewriter.getF16Type()); } else { accType = mlir::TypeAttr::get(outputElemTy);