Skip to content

Commit

Permalink
[TOSA] Update format with pre-commit run
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Ngo <[email protected]>
Change-Id: I1ad7276cf2985ebb2e2a52fdfe596d6a4f28125c
  • Loading branch information
justin-ngo-arm committed Jan 29, 2025
1 parent f7d83a8 commit 44ad48a
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project

namespace mlir {
Expand Down Expand Up @@ -394,7 +394,8 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
auto zeroValue =
tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue).failed())
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
.failed())
return rewriter.notifyMatchFailure(
op, "Failed to equalize ranks among operands and result");

Expand Down Expand Up @@ -493,10 +494,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
outputElemTy.isInteger(48)) {
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
} else if ((isa<Float8E4M3Type>(inputElemTy) && isa<Float8E4M3Type>(weightElemTy) &&
outputElemTy.isF16()) ||
(isa<Float8E5M2Type>(inputElemTy) && isa<Float8E5M2Type>(weightElemTy) &&
outputElemTy.isF16())) {
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
(isa<Float8E5M2Type>(inputElemTy) &&
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
accType = mlir::TypeAttr::get(rewriter.getF16Type());
} else {
accType = mlir::TypeAttr::get(outputElemTy);
Expand Down

0 comments on commit 44ad48a

Please sign in to comment.