From 69b3ebb941ec5553f0d1fd76569817769ca621bb Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 4 Feb 2025 13:03:15 -0800 Subject: [PATCH] [NFC] Add individual switches to polynomial approximation (#19697) Give more freedom when woking with PolynomialApproximationPass, by adding individual switches to each of the approximation ops so they can be fine tuned for different architectures. --------- Co-authored-by: Jakub Kuderski --- .../iree/compiler/Codegen/Common/Passes.td | 6 +++ .../Common/PolynomialApproximationPass.cpp | 50 +++++++++++++------ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index f0d9b19eff05..785e1d477b20 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -529,6 +529,12 @@ def PadDynamicAllocPass : def PolynomialApproximationPass : Pass<"iree-codegen-polynomial-approximation", ""> { let summary = "Convert math operations to their polynomial approximation"; + let options = [ + ListOption<"noApproxOps", "no-approx-ops", "std::string", + [{List of operations that should not be approximated.\n" + "As of now, possible options are:\n" + "\ttan, sinh, cosh, asinh, acosh, atanh, powf, fpowf, erf\n}]>, + ]; } def PropagateDispatchSizeBoundsPass : diff --git a/compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp index 2e9d1c68beb2..82aa080ec3cb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/Passes.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -25,28 +26,47 @@ static llvm::cl::opt clNativeMathPrecision( namespace { +static void populateErfPattern(RewritePatternSet &patterns) { + if (clNativeMathPrecision) { + patterns.add(patterns.getContext()); + } else { + populateExpandExp2FPattern(patterns); + populateMathPolynomialApproximationPatterns(patterns); + populateExpandRoundEvenPattern(patterns); + } +} + /// math dialect elementry functions -> polynomial form. class PolynomialApproximationPass final : public impl::PolynomialApproximationPassBase< PolynomialApproximationPass> { +public: + using Base::Base; + void runOnOperation() override { + using PatternFunction = llvm::function_ref; + // Order matters here. + llvm::SmallVector> patternMap = { + {"tan", populateExpandTanPattern}, + {"sinh", populateExpandSinhPattern}, + {"cosh", populateExpandCoshPattern}, + {"asinh", populateExpandAsinhPattern}, + {"acosh", populateExpandAcoshPattern}, + {"atanh", populateExpandAtanhPattern}, + {"powf", populateExpandPowFPattern}, + {"fpowi", populateExpandFPowIPattern}, + {"erf", populateErfPattern}, + }; + RewritePatternSet mathPatterns(&getContext()); - populateExpandTanPattern(mathPatterns); - populateExpandSinhPattern(mathPatterns); - populateExpandCoshPattern(mathPatterns); - populateExpandAsinhPattern(mathPatterns); - populateExpandAcoshPattern(mathPatterns); - populateExpandAtanhPattern(mathPatterns); - populateExpandPowFPattern(mathPatterns); - populateExpandFPowIPattern(mathPatterns); - - if (clNativeMathPrecision) { - mathPatterns.add(&getContext()); - } else { - populateExpandExp2FPattern(mathPatterns); - populateMathPolynomialApproximationPatterns(mathPatterns); - populateExpandRoundEvenPattern(mathPatterns); + + for (const auto &[fnName, populateFn] : patternMap) { + // Skip any ops in the "do not convert" list. + if (!llvm::is_contained(noApproxOps, fnName)) { + populateFn(mathPatterns); + } } + if (failed( applyPatternsGreedily(getOperation(), std::move(mathPatterns)))) { return signalPassFailure();