diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index a21c9a1d7296..7ed50e866492 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -27,7 +27,7 @@ runs: steps: - name: Set up Python if: ${{ runner.arch == 'X64' }} - uses: actions/setup-python@v4 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.11' @@ -74,7 +74,7 @@ runs: - name: Enable ccache if: ${{ inputs.cache-enabled == 'true' }} - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ github.workspace }}/.ccache key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 3c8b95a3181a..8c571893e145 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -22,7 +22,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -95,7 +95,7 @@ jobs: - name: Post issue comment on build failure if: failure() - uses: peter-evans/create-or-update-comment@v2 + uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0 with: issue-number: 1690 body: | @@ -111,7 +111,7 @@ jobs: - name: Update PyTorch Build Cache (if running on main branch) if: github.ref_name == 'main' id: cache-pytorch - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} @@ -127,7 +127,7 @@ jobs: git pull origin main - name: Create pull request - uses: peter-evans/create-pull-request@v5.0.1 + uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f # v7.0.5 with: author: Roll PyTorch Action branch: rollpytorch diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 23f2addbe5af..4eeef0b9bb5e 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -22,7 +22,7 @@ concurrency: jobs: ubuntu-build: name: ubuntu-x86_64 - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Prepare workspace @@ -32,7 +32,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checkout torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' @@ -40,7 +40,7 @@ jobs: # restore to avoid the cache going stale over time # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache - name: Setup cache for bazel - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ~/.cache/bazel key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} @@ -102,7 +102,7 @@ jobs: - name: Send mail if: failure() - uses: dawidd6/action-send-mail@v3 + uses: dawidd6/action-send-mail@2cea9617b09d79a095af21254fbcb7ae95903dde # v3.12.0 with: server_address: ${{ secrets.SMTP_SERVER }} server_port: ${{ secrets.SMTP_PORT }} diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index e84aabb4b388..a304672b474f 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -28,7 +28,7 @@ jobs: sudo rm 
-rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' fetch-depth: 0 @@ -59,7 +59,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -75,7 +75,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -96,7 +96,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' fetch-depth: 0 @@ -127,7 +127,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -143,7 +143,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -156,7 +156,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -187,7 +187,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -203,7 +203,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -216,7 +216,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -250,7 +250,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -267,13 +267,13 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist publish_releases: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 needs: - build_linux - build_linux_arm64 @@ -285,7 +285,7 @@ jobs: steps: - name: Invoke Publish Releases Page - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Publish releases 
page token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index a0eb45257b11..e87630edb28c 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -8,7 +8,7 @@ on: jobs: scrape_and_publish_releases: name: "Scrape and publish releases" - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' @@ -20,7 +20,7 @@ jobs: # existing lock files. sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: Run scrape releases script diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 58a91fd1d409..e335f1fdfd7d 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -9,7 +9,7 @@ on: jobs: merge-pr: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 if: | github.repository == 'llvm/torch-mlir' && github.event.workflow_run.actor.login == 'stellaraccident' && @@ -18,7 +18,7 @@ jobs: steps: # Fetch the repo first so that the gh command knows where to look for the PR - name: Fetch Repo - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index ec1878606624..92d732cea3a6 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -18,7 +18,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -43,16 +43,15 @@ jobs: - name: Create Release id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 with: - tag_name: ${{ env.tag_name }} - release_name: torch-mlir snapshot ${{ env.tag_name }} + tag: ${{ env.tag_name }} + name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. 
          draft: true
          prerelease: false
+         token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

      - name: "Invoke workflow :: Build and Test"
        uses: benc-uk/workflow-dispatch@v1
diff --git a/.github/workflows/pre-commit-all.yml b/.github/workflows/pre-commit-all.yml
index e17d4ebdbb43..2c0d61e92747 100644
--- a/.github/workflows/pre-commit-all.yml
+++ b/.github/workflows/pre-commit-all.yml
@@ -6,10 +6,10 @@ on:

 jobs:
   pre-commit:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v3
-      - uses: pre-commit/action@v3.0.1
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         with:
           extra_args: --color=always --all-files
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index 29733c2e5d45..6a848fe8674f 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -5,13 +5,13 @@ on:

 jobs:
   pre-commit:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           # required to grab the history of the PR
           fetch-depth: 0
-      - uses: actions/setup-python@v3
-      - uses: pre-commit/action@v3.0.1
+      - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         with:
           extra_args: --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}
diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml
index 8a0ec914440f..7b575764ac8e 100644
--- a/.github/workflows/releaseSnapshotPackage.yml
+++ b/.github/workflows/releaseSnapshotPackage.yml
@@ -9,7 +9,7 @@ jobs:
   release_snapshot_package:
     name: "Tag snapshot release"
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     # Don't run this in everyone's forks.
     if: github.repository == 'llvm/torch-mlir'
     steps:
@@ -21,7 +21,7 @@ jobs:
           sudo rm -rf $GITHUB_WORKSPACE/*

       - name: Checking out repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

@@ -46,26 +46,25 @@ jobs:
       - name: Create Release
         id: create_release
-        uses: actions/create-release@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
+        uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0
         with:
-          tag_name: ${{ env.tag_name }}
-          release_name: torch-mlir snapshot ${{ env.tag_name }}
+          tag: ${{ env.tag_name }}
+          name: torch-mlir snapshot ${{ env.tag_name }}
           body: |
             Automatic snapshot release of torch-mlir.
          draft: true
          prerelease: false
+         token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

      - name: "Invoke workflow :: Build and Test"
-       uses: benc-uk/workflow-dispatch@v1
+       uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
        with:
          workflow: Build and Test
          token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
          ref: "${{ env.tag_name }}"

      - name: "Invoke workflow :: Release Build"
-       uses: benc-uk/workflow-dispatch@v1
+       uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
        with:
          workflow: Release Build
          token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
index a6d774a64db1..221745b1c26e 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
@@ -12,12 +12,25 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <set>

 namespace mlir {
 namespace torch {
+
+/// Collect a set of legal/illegal ops for converting Torch operations to the
+/// Tosa dialect.
+void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
+
+/// Collect a set of patterns to convert Torch operations to the Tosa dialect
+/// and return the set of illegal op names.
+std::set<StringRef>
+populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
+                                                   RewritePatternSet &patterns);
+
 std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
-}
+} // namespace torch
 } // namespace mlir

 #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
index d21dd5504dcd..264fb4966d39 100644
--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                          Value torchOptionalInt, Value builtinInt,
                          Value defaultValue, Value dimSize);

+// Helper function to unsqueeze the input tensor at the given dim.
+// Returns the unsqueezed tensor or failure.
+FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
+                                 Value input, int64_t dim);
+
+// Helper function to squeeze the input tensor at the given dim.
+// Returns the squeezed tensor or failure.
+FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
+                               Value input, int64_t dim);
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c33a9d717eac..c41c90aa2b0d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4555,6 +4555,29 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
   }];
 }

+def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSignOp : Torch_Op<"aten.sign", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 7446b7faaa08..12d8683bc9d1 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1093,18 +1093,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         rewriter.replaceOp(binder.op, nllLoss);
         return success();
       });
-  patterns.onOp("NonZero", 13,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType)) {
-                    return failure();
-                  }
-                  rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        Value one = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
+        auto rawSize = resultType.getSizes();
+        SmallVector<int64_t> torchResultSize(rawSize.rbegin(), rawSize.rend());
+        auto torchResultType = rewriter.getType<Torch::ValueTensorType>(
+            torchResultSize, resultType.getDtype());
+        auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
+            binder.getLoc(), torchResultType, operand);
+        // The output tensor has a shape of (n, z), where n is the
+        // number of dimensions in the input tensor and z is the
+        // number of non-zero elements. This is different from
+        // PyTorch's default behavior, where the dimensions are
+        // reversed.
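+        // For example, given the input [[1, 0], [0, 2]], the non-zero
+        // elements sit at (0, 0) and (1, 1): aten.nonzero produces
+        // [[0, 0], [1, 1]] of shape (z, n), and the transpose below
+        // rearranges it into the ONNX layout [[0, 1], [0, 1]] of
+        // shape (n, z).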
+ rewriter.replaceOpWithNewOp( + binder.op, resultType, nonZero, zero, one); + return success(); + }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; @@ -3671,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( patterns.onOp( "NonMaxSuppression", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); Torch::ValueTensorType resultType; SmallVector operands; int64_t centerPointBox; @@ -3685,34 +3703,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, "unimplemented: expected center_point_box " "attribute value to be 0"); - // TODO: Add support for optional arguments to be absent. - if (operands.size() < 4) - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: expected at least 4 arguments"); - + // TODO: Support multiple batches and classes // Squeeze the boxes and scores tensor. // In Onnx, the shape of boxes is [BxNx4] while the // torchvision expects it to be of shape [Nx4]. Similarly, for // the scores tensor shape in Onnx is [BxCxN] while the // torchvision expects it to be of shape [N]. Value boxes = operands[0], scores = operands[1]; - FailureOr squeezedBoxes = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, boxes); + FailureOr squeezedBoxes = + Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes); if (failed(squeezedBoxes)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze boxes tensor"); - - FailureOr squeezedScores = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, scores); + FailureOr squeezedScores = + Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores); if (failed(squeezedScores)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze scores tensor"); - squeezedScores = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value()); + squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0, + squeezedScores.value()); if (failed(squeezedScores)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze scores tensor"); - boxes = squeezedBoxes.value(); scores = squeezedScores.value(); @@ -3720,61 +3732,103 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Filter out the boxes if the score < score_threshold if (operands.size() == 5) { Value scoreThreshold = rewriter.create( - binder.getLoc(), rewriter.getType(), - operands[4]); + loc, rewriter.getType(), operands[4]); Value minScores = rewriter.create( - binder.getLoc(), + loc, Torch::ValueTensorType::get(binder.op->getContext(), SmallVector{}, rewriter.getF32Type()), scores); minScores = rewriter.create( - binder.getLoc(), rewriter.getType(), minScores); + loc, rewriter.getType(), minScores); Value scoresCond = rewriter.create( - binder.getLoc(), minScores, scoreThreshold); + loc, minScores, scoreThreshold); rewriter.create( - binder.getLoc(), scoresCond, + loc, scoresCond, rewriter.getStringAttr( "unimplemented: score_threshold should be <= min(scores)")); } - // TODO: Support default iou_threshold - Value iouThreshold = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[3]); + // Get max_output_boxes_per_class and iou_threshold + Value cst0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value maxOutputBoxesPerClass = cst0; + Value iouThreshold = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0)); + if (operands.size() > 3 && + !isa(operands[3].getType())) { + 
iouThreshold = rewriter.create( + loc, rewriter.getType(), operands[3]); + } + if (operands.size() > 2 && + !isa(operands[2].getType())) { + maxOutputBoxesPerClass = rewriter.create( + loc, rewriter.getType(), operands[2]); + } + auto nmsTy = Torch::ValueTensorType::get( + binder.op->getContext(), SmallVector{-1}, + rewriter.getIntegerType(64, /*signed=*/true)); + Value result = rewriter.create( + loc, nmsTy, boxes, scores, iouThreshold); + + // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class + Value numOutputBoxes = + rewriter.create(loc, result, cst0); + Value boxesCond = rewriter.create( + loc, numOutputBoxes, maxOutputBoxesPerClass); + + auto nmsResultTy = Torch::ValueTensorType::get( binder.op->getContext(), SmallVector{resultType.getSizes()[0]}, rewriter.getIntegerType(64, /*signed=*/true)); - Value result = rewriter.create( - binder.getLoc(), nmsTy, boxes, scores, iouThreshold); + auto ifSlice = rewriter.create( + loc, TypeRange({nmsResultTy}), boxesCond); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifSlice.getThenRegion(), + ifSlice.getThenRegion().begin()); + + Value curResult = rewriter.create( + loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0, + /*end=*/maxOutputBoxesPerClass, /*step=*/cst1); + rewriter.create(loc, curResult); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifSlice.getElseRegion(), + ifSlice.getElseRegion().begin()); + + Value curResult = rewriter.create( + loc, nmsResultTy, result); + rewriter.create(loc, curResult); + } + result = ifSlice.getResult(0); // The result generated by torchvision.nms op is of shape [n], while the // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor // and make it of shape [n, 1] and then concatenate it with a zero // tensor of shape [n, 2] to make it of shape [n, 3]. 
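+          // For example, if nms keeps box indices [2, 0, 5], unsqueezing
+          // yields [[2], [0], [5]], and concatenating the zero tensor in
+          // front produces [[0, 0, 2], [0, 0, 0], [0, 0, 5]]: one
+          // (batch_index, class_index, box_index) triple per kept box for
+          // the single batch and class supported here.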
- Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); FailureOr unsqueezedResult = - Torch::unsqueezeTensor(rewriter, binder.op, result, dim); + Torch::unsqueezeTensor(rewriter, binder.op, result, cst1); if (failed(unsqueezedResult)) return rewriter.notifyMatchFailure( binder.op, "failed to unsqueeze result tensor"); result = unsqueezedResult.value(); - Value numOutputBoxes = rewriter.create( - binder.getLoc(), result, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0))); + numOutputBoxes = + rewriter.create(loc, result, cst0); SmallVector zerosShapeValues{numOutputBoxes}; zerosShapeValues.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2))); + loc, rewriter.getI64IntegerAttr(2))); Value zerosShapeList = rewriter.create( - binder.getLoc(), + loc, rewriter.getType( rewriter.getType()), zerosShapeValues); - std::optional> resultShape = cast(result.getType()).getOptionalSizes(); if (!resultShape.has_value()) @@ -3783,10 +3837,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( llvm::SmallVector zerosShape = {resultShape->front(), 2}; auto zerosTy = Torch::ValueTensorType::get( resultType.getContext(), zerosShape, resultType.getOptionalDtype()); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = rewriter.create(loc); Value zeros = rewriter.create( - binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone, - cstNone); + loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone); Type listElemType = cast(resultType) @@ -3794,22 +3847,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value tensorList = rewriter.create( - binder.getLoc(), listType, SmallVector{zeros, result}); - - // TODO: Support max_output_boxes_per_class input - // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class - Value maxOutputBoxesPerClass = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[2]); - Value boxesCond = rewriter.create( - binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass); - rewriter.create( - binder.getLoc(), boxesCond, - rewriter.getStringAttr( - "unimplemented: number of output boxes per class should be " - "<= max_output_boxes_per_class")); - + loc, listType, SmallVector{zeros, result}); rewriter.replaceOpWithNewOp(binder.op, resultType, - tensorList, dim); + tensorList, cst1); return success(); }); } diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index a18c0bae01fc..b8c20bc73f65 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - Value input = adaptor.getSelf(); - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - - if (inputRank == 0) { - return rewriter.notifyMatchFailure( - op, "zero input rank should have been handled by the folder"); - } - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - - // assert dynamic squeeze dim size == 1 - if (inputType.isDynamicDim(dim)) { - Value 
cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
-      Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
-      Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
-      Value cmp = rewriter.create<arith::CmpIOp>(
-          op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
-      rewriter.create<cf::AssertOp>(
-          op.getLoc(), cmp,
-          rewriter.getStringAttr(
-              "Expected dynamic squeeze dim size to be statically 1"));
-    }
-
-    const TypeConverter *typeConverter = getTypeConverter();
-    auto resultType =
-        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
-    int64_t resultRank = resultType.getRank();
-
-    // If the dim(th) dimension of operand tensor type is not statically unit,
-    // `aten.squeeze` will behave as an identity operation.
-    if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
-      return success();
+    auto squeezeTensorInfo =
+        squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
+    if (failed(squeezeTensorInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate squeeze tensor");
     }
-    SmallVector<ReassociationIndices> reassociationMap(resultRank);
-    bool alreadyCrossedSqueezedDim = false;
-    for (int i = 0; i != resultRank; i++) {
-      if (alreadyCrossedSqueezedDim) {
-        reassociationMap[i].push_back(i + 1);
-      } else {
-        reassociationMap[i].push_back(i);
-        if (dim != 0 && i != dim - 1)
-          continue;
-
-        alreadyCrossedSqueezedDim = true;
-        if (dim == 0)
-          reassociationMap[0].push_back(1);
-        if (i == dim - 1)
-          reassociationMap[i].push_back(dim);
-      }
-    }
-    // Note: In case the operand tensor type is of unit rank and is statically
-    // shaped with unit dimension, the `reassociationMap` will be empty and the
-    // input will be collapsed to a 0-D tensor.
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
-                                                         reassociationMap);
+    rewriter.replaceOp(op, squeezeTensorInfo.value());
     return success();
   }
 };
@@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
     int64_t dim;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
-    auto inputRank =
-        cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
-    dim = toPositiveDim(dim, inputRank + 1);
-    if (!isValidDim(dim, inputRank + 1))
-      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
-    SmallVector<ReassociationIndices> reassociationMap(inputRank);
-    // From the perspective of the reassociation map, the situation of
-    // unsqueezing before or after the last dimension is symmetrical.
-    // Normalize it to the "before" case.
-    // The 0 case is special here, since there is no last dimension to insert
-    // before -- we simply rely on the loop below iterating 0 times.
-    if (dim == inputRank && inputRank != 0)
-      dim = inputRank - 1;
-    bool alreadyCrossedExpandedDim = false;
-    for (int i = 0; i != inputRank; i++) {
-      if (alreadyCrossedExpandedDim) {
-        reassociationMap[i].push_back(i + 1);
-      } else {
-        reassociationMap[i].push_back(i);
-        if (i == dim) {
-          reassociationMap[i].push_back(i + 1);
-          alreadyCrossedExpandedDim = true;
-        }
-      }
+    auto unsqueezeTensorInfo =
+        unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
+    if (failed(unsqueezeTensorInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor");
     }
-    auto resultType = cast<RankedTensorType>(
-        getTypeConverter()->convertType(op->getResult(0).getType()));
-    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        op, resultType, adaptor.getSelf(), reassociationMap);
+
+    rewriter.replaceOp(op, unsqueezeTensorInfo.value());
     return success();
   }
 };
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9ec7761704ea..9073c5846f33 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -727,15 +727,21 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
     // Check that the matrix shapes are valid for multiplication.
     checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);

+    Type accumulatorDType = getDefaultAccType(rewriter, resultElementType);
     Value initTensor0 = createZeroInitTensor(
-        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2},
-        resultElementType);
+        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType);
     Value bmm =
         rewriter
             .create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
                                            ValueRange{lhs, rhs}, initTensor0)
             .getResult(0);
+
+    if (accumulatorDType != resultElementType) {
+      bmm = torch_to_linalg::convertTensorToElementType(rewriter, loc, bmm,
+                                                        resultElementType);
+    }
+
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
     return success();
   }
@@ -850,6 +856,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");

+    // Checks for valid group size
+    int64_t numGroups;
+    if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
+      return rewriter.notifyMatchFailure(op,
+                                         "only constant group size supported.");
+    Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
+
+    // Adding support for 1d group convolution by converting the 1d-conv to
+    // 2d-conv.
+    // TODO: Replace this logic with the appropriate linalg op for 1-d group
+    // convolution once that support is added.
+    bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1);
+    if (is1DGroupConv) {
+      // Unsqueezing the last dim of input and weight. Also extending the
+      // dilation, stride, padding, and output padding lists.
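+      // For example, an input of shape (N, C, L) becomes (N, C, L, 1) and a
+      // weight of shape (F, C/G, K) becomes (F, C/G, K, 1); a zero padding
+      // and output padding and a unit stride and dilation are appended for
+      // the new trailing dimension, so the existing 2d convolution path
+      // below applies unchanged.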
+ auto unsqueezeInputInfo = + unsqueezeTensor(rewriter, op, input, /*dim=*/-1); + if (failed(unsqueezeInputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + input = unsqueezeInputInfo.value(); + + auto unsqueezeWeightInfo = + unsqueezeTensor(rewriter, op, weight, /*dim=*/-1); + if (failed(unsqueezeWeightInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + weight = unsqueezeWeightInfo.value(); + + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + paddingIntValues.push_back(cstZero); + outputPaddingIntValues.push_back(cstZero); + strideInts.push_back(1); + dilationInts.push_back(1); + + inRank++; + numSpatialDims++; + } + Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; @@ -861,13 +909,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Checks for valid group size - int64_t numGroups; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) - return rewriter.notifyMatchFailure(op, - "only constant group size supported."); - Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); - auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); @@ -1280,13 +1321,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } + rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } if (numSpatialDims != 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); + op, "unimplemented: only 1D and 2D grouped convolution supported"); // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { @@ -1371,6 +1423,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. 
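+      // This restores the original 1d layout, e.g. (N, F, L_out, 1) ->
+      // (N, F, L_out).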
+ auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index cf41bbcd711b..98dbc1957892 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -116,6 +116,22 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, else division = b.createOrFold(loc, dividend, strideInt); Value out = b.createOrFold(loc, division, c1); + + if (ceilMode) { + Value outMinusOneTimesStride = + b.createOrFold(loc, division, strideInt); + Value inAddLeftPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), paddingInt); + + auto reduceOutputDimCond = + b.createOrFold(loc, arith::CmpIPredicate::uge, + outMinusOneTimesStride, inAddLeftPadding); + + auto reducedDim = b.createOrFold(loc, reduceOutputDimCond, + division, out); + return castIntToIndex(b, loc, reducedDim); + } + return castIntToIndex(b, loc, out); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 9572723fdd29..1c2f7d6f2a11 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5398,9 +5398,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { } else { int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; - if (ceilMode && (dimSize % stride != 0)) - return dimSize / stride + 2; - return dimSize / stride + 1; + int64_t outputDim = dimSize / stride + 1; + if (ceilMode && (dimSize % stride != 0) && + (outputDim * stride < inputDim + padBefore)) + outputDim++; + return outputDim; } } @@ -8256,6 +8258,198 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.unfold +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnfoldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Approach: Use GatherOp to retrieve target elements from target dim and then + // reshape the output into slices according to the output shape + // + // Lowering steps: + // 1. Create PyTorch-style indices tensor corresponding to target elements and + // reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1)) + // with d_x being the dimension size of the input at dim x. + // The indices vector will be calculated using the following formula: + // for i in range(d_0 * d_1 * ... * d_(target_dim - 1)): + // for window in range(nWindows): + // for elementIndex in range(size): + // for j in range(d_(target_dim + 1) * ... * d_(rank-1)): + // indices_vec.push_back(elementIndex + window * step) + // 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices + // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + // 4. 
Reshape result from above to correct output shape + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + int64_t dim; + if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, + "Only constant int dims are supported"); + + int64_t size; + if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) + return rewriter.notifyMatchFailure(op, + "Only constant int sizes are supported"); + + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + return rewriter.notifyMatchFailure(op, + "Only constant int steps are supported"); + + if (step <= 0) + return rewriter.notifyMatchFailure(op, "Step value must be greater than 0"); + + // Handle rank zero + if (selfRank == 0) { + if (dim != 0) + return rewriter.notifyMatchFailure( + op, "Unsupported dim value for rank zero input"); + + if (size != 1) + return rewriter.notifyMatchFailure( + op, "Unsupported size value for rank zero input"); + + auto result = rewriter.create( + op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, + rewriter.getDenseI64ArrayAttr({1})); + + rewriter.replaceOp(op, {result.getResult()}); + return success(); + } + + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Dim value is invalid"); + + // Size of dimension 'dim' in the returned tensor (or number of windows within + // the dimension that got sliced) + int64_t nWindows = (selfShape[dim] - size) / step + 1; + + // Find number of times that each base index value gets repeated for target + // dim based on dim values before and after target dim i.e. preDimAccumulate = + // d_0 * d_1 * ... * d_(target_dim - 1) + // postDimAccumulate = d_(target_dim + 1) * ... 
* d_(rank - 1) + int64_t preDimAccumulate = + std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1, + std::multiplies()); + int64_t postDimAccumulate = + std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1, + std::multiplies()); + + // Calculate PyTorch-style gather indices vector + // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1 + // -> preDimAccumulate = 2, postDimAccummulate = 3, nWindows = 2 + // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3] + // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3, + // 0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3] + SmallVector pyTorchIndicesBaseVec; + SmallVector pyTorchIndicesVec; + + for (int64_t window = 0; window < nWindows; window++) { + for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) { + int32_t baseIndex = static_cast(elementIndex + window * step); + for (int64_t i = 0; i < postDimAccumulate; i++) + pyTorchIndicesBaseVec.push_back(baseIndex); + } + } + + for (int64_t i = 0; i < preDimAccumulate; i++) + pyTorchIndicesVec.insert(pyTorchIndicesVec.end(), + pyTorchIndicesBaseVec.begin(), + pyTorchIndicesBaseVec.end()); + + // Create the PyTorch-style indices tensor + // Continuing with the previous example: + // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3) + // pyTorchIndices = tensor([[[0, 0, 0], + // [1, 1, 1], + // [2, 2, 2], + // [1, 1, 1], + // [2, 2, 2], + // [3, 3, 3]], + // [[0, 0, 0], + // [1, 1, 1], + // [2, 2, 2], + // [1, 1, 1], + // [2, 2, 2], + // [3, 3, 3]]]) + SmallVector pyTorchIndicesShape(selfShape); + pyTorchIndicesShape[dim] = nWindows * size; + auto pyTorchIndices = + tosa::getConstTensor(rewriter, op, pyTorchIndicesVec, + pyTorchIndicesShape) + .value(); + + // Convert PyTorch-style indices to TensorFlow-style indices + auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self, + pyTorchIndices, dim); + if (!tfIndices) + return rewriter.notifyMatchFailure(op, + "Convert PyTorch-style indices and dim " + "to TensorFlow-style indices failed"); + + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + auto gatherNdOp = tosa::convertGatherNdOp( + rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy), + self, tfIndices.value()); + if (!gatherNdOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + // Reshape to an intermediary shape where the gathered elements in dimension + // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size' + SmallVector intermediaryShape; + for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) { + if (currentDim == dim) { + intermediaryShape.push_back(nWindows); + intermediaryShape.push_back(size); + } else { + intermediaryShape.push_back(pyTorchIndicesShape[currentDim]); + } + } + + auto reshapeOp = rewriter.create( + op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy), + gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape)); + + // Permute dims to the correct result order + SmallVector permutedDims; + for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) { + if (currentDim != dim + 1) + permutedDims.push_back(static_cast(currentDim)); + } + permutedDims.push_back(static_cast(dim + 1)); + + auto permutedDimsConst = tosa::getConstTensor( + rewriter, op, + /*vec=*/permutedDims, + /*shape=*/{static_cast(selfRank + 1)}) + .value(); + + auto result = rewriter.create( + op->getLoc(), resultType, 
reshapeOp.getResult(), permutedDimsConst); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -8277,342 +8471,357 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { ConversionTarget target(*context); target.addLegalDialect(); + target.addIllegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - // The following ops are never the primary reason why lowering fails. - // The backend contract only allows functions to return tensors thus there - // is always another op using them. - // When we have a chain of torch.constant.int followed by a unsupported - // torch op, we want the pass to mention the unsupported torch op - // in the error message. - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addIllegalDialect(); + populateTorchToTosaConversionLegalOps(target); RewritePatternSet patterns(context); + auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps( + typeConverter, patterns); + + for (auto op : illegalOps) { + target.addIllegalOp(OperationName(op, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors thus there + // is always another op using them. + // When we have a chain of torch.constant.int followed by a unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. 
+ target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); +} + +std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + + MLIRContext *context = patterns.getContext(); + std::set illegalOps; + #define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN #define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) - INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) - INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) - INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) - INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) - INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) - INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) - INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) + INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) + INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) + INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) + INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) - INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) - INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) - INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) - INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) - INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, - tosa::LogicalLeftShiftOp) - INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, - tosa::ArithmeticRightShiftOp) + INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) + INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) + INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) + INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) + INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp) + INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, + tosa::ArithmeticRightShiftOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) + 
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); #undef INSERT_BINARY_MUL_PATTERN #define INSERT_BINARY_DIV_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); + 
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); #undef INSERT_BINARY_DIV_PATTERN #define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); #undef INSERT_REMAINDER_FMOD_OP_PATTERN #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, - mlir::tosa::convertReduceMeanOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, - mlir::tosa::convertReduceSumOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, - mlir::tosa::convertLinalgVectorNormOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, + mlir::tosa::convertReduceMeanOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, + mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, - mlir::tosa::convertReduceAllOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, - mlir::tosa::convertReduceProdOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, + mlir::tosa::convertReduceAllOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ONEDIM_REDUCTION_OP_PATTERN #define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, - mlir::tosa::convertReduceAllOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, - mlir::tosa::convertReduceSumOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, - mlir::tosa::convertReduceMaxOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, - mlir::tosa::convertReduceMinOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, - mlir::tosa::convertReduceProdOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN #define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, 
TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); #undef INSERT_INDICES_REDUCTION_OP_PATTERN #define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) #undef INSERT_SQUEEZE_OP_PATTERN #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); #undef INSERT_MATMUL_ATEMOP_PATTERN #define INSERT_MM_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MM_ATENOP_PATTERN(AtenMmOp); - INSERT_MM_ATENOP_PATTERN(AtenBmmOp); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); #undef INSERT_MM_ATEMOP_PATTERN #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); #undef INSERT_LINEAR_ATEMOP_PATTERN #define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, - tosa::AvgPool2dOp); + INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, + tosa::AvgPool2dOp); #undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool1dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool1dOp::getOperationName()); + patterns.add(typeConverter, context); #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); - INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); - INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + 
+  illegalOps.insert(AtenOp::getOperationName());                               \
   patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
-    INSERT_FILL_PATTERN(AtenFill_ScalarOp);
-    INSERT_FILL_PATTERN(AtenFillScalarOp);
-    INSERT_FILL_PATTERN(AtenFillTensorOp);
+  INSERT_FILL_PATTERN(AtenFill_ScalarOp);
+  INSERT_FILL_PATTERN(AtenFillScalarOp);
+  INSERT_FILL_PATTERN(AtenFillTensorOp);
 #undef INSERT_FILL_PATTERN
 
 #define INSERT_MASKED_FILL_PATTERN(AtenOp)                                     \
-  target.addIllegalOp<AtenOp>();                                               \
+  illegalOps.insert(AtenOp::getOperationName());                               \
   patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
-    INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
-    INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
+  INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
+  INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
 #undef INSERT_MASKED_FILL_PATTERN
 
 #define INSERT_POW_OP_PATTERN(AtenOp)                                          \
-  target.addIllegalOp<AtenOp>();                                               \
+  illegalOps.insert(AtenOp::getOperationName());                               \
   patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
-    INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
-    INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
-    INSERT_POW_OP_PATTERN(AtenPowScalarOp);
+  INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
+  INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
+  INSERT_POW_OP_PATTERN(AtenPowScalarOp);
 #undef INSERT_POW_OP_PATTERN
 
+#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp)                  \
+  illegalOps.insert(AtenOp::getOperationName());                               \
+  patterns.add<ConvertUpsampleNearest2dForward<AtenOp>>(typeConverter, context);
+  INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp);
+  INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp);
+#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN
+
 #define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp)                  \
-  target.addIllegalOp<AtenOp>();                                               \
+  illegalOps.insert(AtenOp::getOperationName());                               \
   patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
                                                                 context);
-    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
-    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
-    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
+  INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
+  INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
+  INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
 #undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN
 
-#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp)                  \
-  target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertUpsampleNearest2dForward<AtenOp>>(typeConverter, context);
-    INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp);
-    INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp);
-#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN
-
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
-  target.addIllegalOp<AtenOp>();                                               \
+  illegalOps.insert(AtenOp::getOperationName());                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
-    INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenReluOp);
-    INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
-    INSERT_ATENOP_PATTERN(AtenArgmaxOp);
-    INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
-    INSERT_ATENOP_PATTERN(AtenConvolutionOp);
-    INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
-    INSERT_ATENOP_PATTERN(AtenReshapeOp);
-    INSERT_ATENOP_PATTERN(AtenBatchNormOp);
-    INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
-    INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
-    INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
-    INSERT_ATENOP_PATTERN(AtenPermuteOp);
-    INSERT_ATENOP_PATTERN(AtenLog2Op);
-    INSERT_ATENOP_PATTERN(AtenThresholdOp);
-    INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
-    INSERT_ATENOP_PATTERN(AtenContiguousOp);
-    INSERT_ATENOP_PATTERN(AtenDropoutOp);
-    INSERT_ATENOP_PATTERN(AtenViewOp);
-    INSERT_ATENOP_PATTERN(AtenGeluOp);
-    INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
-    INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
-    INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
-    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
-    INSERT_ATENOP_PATTERN(AtenGatherOp);
-    INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
-    INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
-    INSERT_ATENOP_PATTERN(AtenAbsOp);
-    INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
-    INSERT_ATENOP_PATTERN(AtenClampOp);
-    INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
-    INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
-    INSERT_ATENOP_PATTERN(AtenCopyOp);
-    INSERT_ATENOP_PATTERN(AtenToDtypeOp);
-    INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
-    INSERT_ATENOP_PATTERN(AtenCatOp);
-    INSERT_ATENOP_PATTERN(AtenSqrtOp);
-    INSERT_ATENOP_PATTERN(AtenIscloseOp);
-    INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
-    INSERT_ATENOP_PATTERN(AtenTrilOp);
-    INSERT_ATENOP_PATTERN(AtenDiagonalOp);
-    INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
-    INSERT_ATENOP_PATTERN(AtenFlipOp);
-    INSERT_ATENOP_PATTERN(AtenRoundOp);
-    INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
-    INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
-    INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
-    INSERT_ATENOP_PATTERN(AtenUniformOp);
-    INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenAsStridedOp);
-    INSERT_ATENOP_PATTERN(AtenClampTensorOp);
-    INSERT_ATENOP_PATTERN(PrimsCollapseOp);
-    INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
-    INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp);
-    INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
-    INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
-    INSERT_ATENOP_PATTERN(AtenOuterOp);
-    INSERT_ATENOP_PATTERN(AtenLogitOp);
-    INSERT_ATENOP_PATTERN(AtenLog1pOp);
-    INSERT_ATENOP_PATTERN(AtenLog10Op);
-    INSERT_ATENOP_PATTERN(AtenTanOp);
+  INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
+  INSERT_ATENOP_PATTERN(AtenReluOp);
+  INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
+  INSERT_ATENOP_PATTERN(AtenArgmaxOp);
+  INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
+  INSERT_ATENOP_PATTERN(AtenConvolutionOp);
+  INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
+  INSERT_ATENOP_PATTERN(AtenReshapeOp);
+  INSERT_ATENOP_PATTERN(AtenBatchNormOp);
+  INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
+  INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
+  INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
+  INSERT_ATENOP_PATTERN(AtenPermuteOp);
+  INSERT_ATENOP_PATTERN(AtenLog2Op);
+  INSERT_ATENOP_PATTERN(AtenThresholdOp);
+  INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
+  INSERT_ATENOP_PATTERN(AtenContiguousOp);
+  INSERT_ATENOP_PATTERN(AtenDropoutOp);
+  INSERT_ATENOP_PATTERN(AtenViewOp);
+  INSERT_ATENOP_PATTERN(AtenGeluOp);
+  INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
+  INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
+  INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
+  INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
+  INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
+  INSERT_ATENOP_PATTERN(AtenGatherOp);
+  INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
+  INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
+  INSERT_ATENOP_PATTERN(AtenAbsOp);
+  INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
+  INSERT_ATENOP_PATTERN(AtenClampOp);
+  INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
+  INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
+  INSERT_ATENOP_PATTERN(AtenCopyOp);
+  INSERT_ATENOP_PATTERN(AtenToDtypeOp);
+  INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
+  INSERT_ATENOP_PATTERN(AtenCatOp);
+  INSERT_ATENOP_PATTERN(AtenSqrtOp);
+  INSERT_ATENOP_PATTERN(AtenIscloseOp);
+  INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
+  INSERT_ATENOP_PATTERN(AtenTrilOp);
+  INSERT_ATENOP_PATTERN(AtenDiagonalOp);
+  INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
+  INSERT_ATENOP_PATTERN(AtenFlipOp);
+  INSERT_ATENOP_PATTERN(AtenRoundOp);
+  INSERT_ATENOP_PATTERN(AtenScatterSrcOp);
+  INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
+  INSERT_ATENOP_PATTERN(AtenDiagEmbedOp);
+  INSERT_ATENOP_PATTERN(AtenUniformOp);
+  INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
+  INSERT_ATENOP_PATTERN(AtenAsStridedOp);
+  INSERT_ATENOP_PATTERN(AtenClampTensorOp);
+  INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+  INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
+  INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp);
+  INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
+  INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
+  INSERT_ATENOP_PATTERN(AtenOuterOp);
+  INSERT_ATENOP_PATTERN(AtenLogitOp);
+  INSERT_ATENOP_PATTERN(AtenLog1pOp);
+  INSERT_ATENOP_PATTERN(AtenLog10Op);
+  INSERT_ATENOP_PATTERN(AtenTanOp);
+  INSERT_ATENOP_PATTERN(AtenUnfoldOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
-  target.addIllegalOp<AtenOp>();                                               \
+  illegalOps.insert(AtenOp::getOperationName());                               \
   patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
-    INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
+  INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
 #undef INSERT_CLONE_ATENOP_PATTERN
 
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
-} // namespace
+  return illegalOps;
+}
 
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToTosaPass() {
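AtenUnfoldOp is newly added to the generic INSERT_ATENOP_PATTERN list above, so torch.aten.unfold now legalizes to TOSA. As a quick reminder of the op's eager semantics (plain PyTorch, illustrative only — not part of the patch):

```python
import torch

x = torch.arange(8.0)
# unfold(dimension, size, step): sliding windows of length 2, stride 2
print(x.unfold(0, 2, 2))
# tensor([[0., 1.],
#         [2., 3.],
#         [4., 5.],
#         [6., 7.]])
```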
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index e3f5b6d0299a..72217e5f4afd 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -447,6 +447,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
   return castIntToIndex(rewriter, loc, boundedByDimSize);
 }
 
+// Helper function to unsqueeze the input tensor at given dim.
+// Returns the unsqueezed tensor or failure.
+FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
+                                 Value input, int64_t dim) {
+  auto inputType = cast<RankedTensorType>(input.getType());
+  int64_t inputRank = inputType.getRank();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+
+  // `input` has a reduced rank. Hence add 1.
+  int64_t unsqueezedRank = inputShape.size() + 1;
+  dim = toPositiveDim(dim, unsqueezedRank);
+  if (!isValidDim(dim, unsqueezedRank)) {
+    return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+  }
+
+  SmallVector<int64_t> unsqueezedShape{inputShape};
+  unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1);
+  Type unsqueezedType =
+      RankedTensorType::get(unsqueezedShape, inputType.getElementType());
+
+  SmallVector<ReassociationIndices> reassociationMap(inputRank);
+  // From the perspective of the reassociation map, the situation of
+  // unsqueezing before or after the last dimension is symmetrical.
+  // Normalize it to the "before" case.
+  // The 0 case is special here, since there is no last dimension to insert
+  // before -- we simply rely on the loop below iterating 0 times.
+  if (dim == inputRank && inputRank != 0)
+    dim = inputRank - 1;
+  bool alreadyCrossedExpandedDim = false;
+  for (int i = 0; i != inputRank; i++) {
+    if (alreadyCrossedExpandedDim) {
+      reassociationMap[i].push_back(i + 1);
+    } else {
+      reassociationMap[i].push_back(i);
+      if (i == dim) {
+        reassociationMap[i].push_back(i + 1);
+        alreadyCrossedExpandedDim = true;
+      }
+    }
+  }
+  Value unsqueezed = rewriter.create<tensor::ExpandShapeOp>(
+      op->getLoc(), unsqueezedType, input, reassociationMap);
+  return unsqueezed;
+}
+
+// Helper function to squeeze the input tensor at given dim.
+// Returns the squeezed tensor or failure.
+FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
+                               Value input, int64_t dim) {
+  Location loc = op->getLoc();
+  auto inputType = cast<RankedTensorType>(input.getType());
+  int64_t inputRank = inputType.getRank();
+
+  // No scope for squeezing the input.
+  if (inputRank == 0)
+    return input;
+
+  dim = toPositiveDim(dim, inputRank);
+  if (!isValidDim(dim, inputRank))
+    return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+
+  // assert dynamic squeeze dim size == 1
+  if (inputType.isDynamicDim(dim)) {
+    Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
+    Value dimVal = rewriter.create<tensor::DimOp>(loc, input, cstDim);
+    Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                               dimVal, cstOne);
+    rewriter.create<cf::AssertOp>(
+        loc, cmp,
+        rewriter.getStringAttr(
+            "Expected dynamic squeeze dim size to be statically 1"));
+  }
+
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  SmallVector<int64_t> squeezedShape;
+  squeezedShape.append(inputShape.begin(), inputShape.begin() + dim);
+  squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end());
+  int64_t squeezedRank = inputRank - 1;
+  Type squeezedType =
+      RankedTensorType::get(squeezedShape, inputType.getElementType());
+
+  // If the dim(th) dimension of operand tensor type is not statically unit,
+  // squeeze will behave as an identity operation.
+  if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
+    return input;
+  }
+
+  SmallVector<ReassociationIndices> reassociationMap(squeezedRank);
+  bool alreadyCrossedSqueezedDim = false;
+  for (int i = 0; i != squeezedRank; i++) {
+    if (alreadyCrossedSqueezedDim) {
+      reassociationMap[i].push_back(i + 1);
+    } else {
+      reassociationMap[i].push_back(i);
+      if (dim != 0 && i != dim - 1)
+        continue;
+
+      alreadyCrossedSqueezedDim = true;
+      if (dim == 0)
+        reassociationMap[0].push_back(1);
+      if (i == dim - 1)
+        reassociationMap[i].push_back(dim);
+    }
+  }
+  // Note: In case the operand tensor type is of unit rank and is statically
+  // shaped with unit dimension, the `reassociationMap` will be empty and the
+  // input will be collapsed to a 0-D tensor.
+  Value squeezed = rewriter.create<tensor::CollapseShapeOp>(
+      op->getLoc(), squeezedType, input, reassociationMap);
+  return squeezed;
+}
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
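The reassociation maps built by unsqueezeTensor/squeezeTensor above describe how output dims group back onto input dims for tensor.expand_shape/tensor.collapse_shape. A shape-level sketch of the same bookkeeping in PyTorch terms (illustrative, not part of the patch):

```python
import torch

x = torch.randn(2, 3, 4)
# unsqueezeTensor(..., dim=1): (2, 3, 4) -> (2, 1, 3, 4). The reassociation
# map is [[0], [1, 2], [3]]: output dims 1 and 2 both collapse back onto
# input dim 1, which is exactly what the loop above computes for dim == 1.
print(x.unsqueeze(1).shape)             # torch.Size([2, 1, 3, 4])
# squeezeTensor only collapses when the target dim is statically 1;
# otherwise it returns the input unchanged, as the code notes.
print(x.unsqueeze(1).squeeze(1).shape)  # torch.Size([2, 3, 4])
```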
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 933543c18aaf..1fd7102e2ef4 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6495,6 +6495,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -11581,6 +11585,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
@@ -12574,17 +12583,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
-"    %true = torch.constant.bool true\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
-"      torch.prim.If.yield %true : !torch.bool\n"
-"    } else {\n"
-"      %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"      torch.prim.If.yield %3 : !torch.bool\n"
-"    }\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
 "    torch.prim.If %2 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
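The new shape/dtype library entries register aten.special_expm1 as a unary op whose result dtype follows the usual floating-point promotion. A quick eager-mode check of that contract (illustrative only, not part of the patch):

```python
import torch

x = torch.tensor([1, 2])            # integer input
y = torch.special.expm1(x)          # special.expm1 is an alias of expm1
print(y.dtype)                      # torch.float32 -- ints promote to float
print(torch.allclose(y, torch.expm1(x.float())))  # True
```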
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1caac461fe8b..5f963c9e0386 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3738,11 +3738,7 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
       // Create a uniform random op with low and high set to `lower` and
       // `upper`, respectively.
       Value none = rewriter.create<ConstantNoneOp>(loc);
-      Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-          loc, resType, self, constantZeroFloat, /*dtype=*/none,
-          /*layout=*/none,
-          /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
-      alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
+      alpha = rewriter.create<AtenUniformOp>(loc, resType, self,
                                              /*from=*/lower, /*to=*/upper,
                                              /*generator=*/none);
     } else {
@@ -3774,6 +3770,7 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
 };
 } // namespace
 
+
 // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
@@ -5576,6 +5573,240 @@ class DecomposeAtenConvolutionBackwardOp
 };
 } // namespace
 
+/**
+ * # one dim input
+ * t = torch.tensor([0, 0, 1, 1, 0, 0])
+ * # t_flat: [0, 0, 1, 1, 0, 0]
+ * t_flat = t.flatten(0, 0)
+ * nonzero_mask = t_flat != 0
+ * # nonzero_mask: [0, 0, 1, 1, 0, 0]
+ * nonzero_mask = nonzero_mask.long()
+ * # destination_indices: [-1, -1, 0, 1, 1, 1]
+ * destination_indices = torch.cumsum(nonzero_mask, 0) - 1
+ * # destination_indices_clamp: [0, 0, 0, 1, 1, 1]
+ * destination_indices_clamp = torch.clamp(destination_indices, min=0)
+ * # iota: [0, 0, 2, 3, 0, 0]
+ * iota = torch.arange(t_flat.size(0)) * nonzero_mask
+ * # scatter_self: [0, 0, 0, 0, 0, 0]
+ * scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
+ * # compacted: [2, 3, 0, 0, 0, 0]
+ * compacted = torch.scatter_add(
+ *     scatter_self, dim=0, index=destination_indices_clamp, src=iota
+ * )
+ * # result_flat: [2, 3]
+ * result_flat = compacted[: torch.sum(nonzero_mask)]
+ *
+ * # multi dim support
+ * original_shape = t.shape
+ * # input_shape_tensor: [6]
+ * input_shape_tensor = torch.tensor(original_shape)
+ * strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
+ *
+ * one = torch.tensor([1])
+ * if t.dim() > 1:
+ *     slicedStrides = strides[1:-1]
+ *     strides = torch.cat([slicedStrides, one])
+ * else:
+ *     strides = one
+ * # a: tensor([[2], [3]]), torch.Size([2, 1])
+ * a = result_flat.unsqueeze(1)
+ * # b: tensor([[1]]), torch.Size([1, 1])
+ * b = strides.unsqueeze(0)
+ * # c: tensor([[2], [3]]), torch.Size([2, 1])
+ * c = a // b
+ * # result: tensor([[2], [3]]), torch.Size([2, 1])
+ * result = c % input_shape_tensor
+ */
+class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
+  using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNonzeroOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resultType = cast<BaseTensorType>(op.getType());
+    auto intType = resultType.getDtype();
+    Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType);
+    auto constantZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    auto constantOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    std::function<Value(Value)> makeOneElementList = [&](Value element) {
+      auto listType = Torch::ListType::get(element.getType());
+      return rewriter.create<PrimListConstructOp>(loc, listType,
+                                                  ArrayRef<Value>{element});
+    };
+
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+    int64_t inputRank = inputType.getSizes().size();
+
+    // t_flat = t.flatten() # torch.flatten(t, 0, 0)
+    int64_t flattenedSize = 1;
+    if (inputType.hasSizes()) {
+      for (auto size : inputType.getSizes()) {
+        flattenedSize *= size;
+      }
+    } else {
+      flattenedSize = kUnknownSize;
+    }
+
+    auto flattendInputShape = SmallVector<int64_t>{flattenedSize};
+    auto flattenedInputType = rewriter.getType<ValueTensorType>(
+        flattendInputShape, inputType.getOptionalDtype());
+
+    // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 :
+    auto inputDimsEnd = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputRank - 1));
+    Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc, flattenedInputType, input, constantZero /*inputDimsStart*/,
+        inputDimsEnd /*inputDimsEnd*/);
+
+    // nonzero_mask = (t_flat != 0)
+    auto boolMaskType = inputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
+    Value boolMask = rewriter.create<AtenNeScalarOp>(
+        loc, boolMaskType, flattenedInput, constantZero);
+
+    // nonzero_mask = nonzero_mask.int()
+    Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
+    Value noneCst = rewriter.create<ConstantNoneOp>(loc);
+    auto intMaskType = flattenedInputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), intType);
+    Value intMask = rewriter.create<AtenToDtypeOp>(
+        loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst);
+
+    // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
+    Value cumulativeSum = rewriter.create<AtenCumsumOp>(
+        loc, intMaskType, intMask, constantZero, noneCst);
+    Value subtracted = rewriter.create<AtenSubScalarOp>(
+        loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne);
+
+    // destination_indices = torch.clamp(destination_indices, min=0)
+    Value indices = rewriter.create<AtenClampMinOp>(loc, intMaskType,
+                                                    subtracted, constantZero);
+
+    // iota = torch.arange(len(t_flat)) * nonzero_mask
+    Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
+                                               /*dim=*/constantZero);
+    Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
+        loc, intMaskType, /*start*/ constantZero, /*end*/ end,
+        /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst);
+    Value multiplied = rewriter.create<AtenMulTensorOp>(loc, intMaskType,
+                                                        rangeTensor, intMask);
+
+    // scatter_self = torch.zeros_like(t, dtype=torch.int64)
+    // AtenFullLike doesn't support index type so we have to use int.
+    Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
+        loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst,
+        noneCst, noneCst);
+
+    // compacted = torch.scatter_add(
+    //     scatter_self, dim=0, index=destination_indices_clamp, src=iota)
+    Value scatteredTensor = rewriter.create<AtenScatterAddOp>(
+        loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero,
+        /*index=*/indices, /*src=*/multiplied);
+
+    // result_flat = compacted[:torch.sum(nonzero_mask)]
+    auto scalarType = ValueTensorType::get(rewriter.getContext(),
+                                           ArrayRef<int64_t>{}, intType);
+    Value sumMask =
+        rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
+    Value numNonzero = rewriter.create<AtenItemOp>(loc, sumMask);
+
+    auto slicedResultType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
+    Value slicedResult =
+        rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
+                                           /*self=*/scatteredTensor,
+                                           /*dim=*/constantZero,
+                                           /*start=*/noneCst,
+                                           /*end=*/numNonzero,
+                                           /*step=*/constantOne);
+
+    // TODO: fix multidim dynamic support. The following code only works for
+    // static multidim. Convert flattened indices back to multi-dimensional
+    // indices:
+    //   original_shape = t.shape
+    //   input_shape_tensor = torch.tensor(original_shape)
+    auto shapeType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{inputRank}, intType);
+    SmallVector<Value> shapeValues;
+    for (int i = 0; i < inputRank; i++) {
+      auto constantI =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      Value shape = rewriter.create<AtenSizeIntOp>(loc, input,
+                                                   /*dim=*/constantI);
+      shapeValues.push_back(shape);
+    }
+    Value shapeTensorList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues);
+    Value inputShapeTensor = rewriter.create<Torch::AtenTensorOp>(
+        loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst);
+
+    // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0)
+    Value flippedShape = rewriter.create<AtenFlipOp>(
+        loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
+    Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
+        loc, shapeType, flippedShape, constantZero, noneCst);
+    Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
+        loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
+
+    // strides = torch.cat([strides[1:-1], torch.tensor([1])])
+    auto oneTensorType = ValueTensorType::get(rewriter.getContext(),
+                                              SmallVector<int64_t>{1}, intType);
+    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
+        loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst,
+        noneCst);
+
+    Value strides;
+    if (inputRank > 1) {
+      // strides[1:-1]
+      auto slicedStrideType = Torch::ValueTensorType::get(
+          rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
+          intType);
+      Value strideSliceEnd = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(inputRank));
+      Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
+          loc, slicedStrideType, /*self*/ flippedCumulativeProduct,
+          /*dim*/ constantZero,
+          /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);
+      // torch.cat
+      auto tensorListElementType = Torch::ValueTensorType::get(
+          rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, intType);
+      Value tensorList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(tensorListElementType),
+          SmallVector<Value>{slicedStrides, oneTensor});
+      strides = rewriter.create<AtenCatOp>(loc, shapeType, tensorList,
+                                           constantZero);
+    } else {
+      // strides[1:-1] is empty
+      strides = oneTensor;
+    }
+
+    // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
+    //                 input_shape_tensor
+    auto unsqueezedResultType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, intType);
+    Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezedResultType, slicedResult, constantOne);
+
+    auto unsqueezedStridesType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, intType);
+    Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
+        loc, unsqueezedStridesType, strides, constantZero);
+
+    auto dividedBroadcastType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
+        intType);
+    Value divided = rewriter.create<AtenFloorDivideOp>(
+        loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
+
+    Value modded = rewriter.create<AtenRemainderTensorOp>(
+        loc, resultType, divided, inputShapeTensor);
+
+    rewriter.replaceOp(op, modded);
+    return success();
+  }
+};
+
 // Decompose aten.addmm into aten.mm and aten.add.Tensor op.
 namespace {
 class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
@@ -11048,6 +11279,19 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern<TorchvisionNmsOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenSpecialExpm1Op
+    : public OpRewritePattern<AtenSpecialExpm1Op> {
+public:
+  using OpRewritePattern<AtenSpecialExpm1Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSpecialExpm1Op op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenExpm1Op>(op, op.getType(), op.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -11121,6 +11365,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(
         patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
@@ -11330,6 +11575,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
         patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenSpecialExpm1Op>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenFMaxMinOp>(patterns);
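The DecomposeAtenNonzeroOp pattern registered above implements the cumsum/scatter_add compaction from its doc comment. A runnable Python version of the 1-D core, checked against torch.nonzero (illustrative sketch, not part of the patch):

```python
import torch

def nonzero_1d(t: torch.Tensor) -> torch.Tensor:
    mask = (t.flatten(0, 0) != 0).long()
    dest = torch.clamp(torch.cumsum(mask, 0) - 1, min=0)
    iota = torch.arange(mask.size(0)) * mask
    compacted = torch.zeros_like(mask).scatter_add(0, dest, iota)
    return compacted[: mask.sum()]

t = torch.tensor([0, 0, 1, 1, 0, 0])
print(nonzero_1d(t))               # tensor([2, 3])
print(torch.nonzero(t).flatten())  # tensor([2, 3]) -- same result

# Multi-dim indices are then recovered with the strides trick from the
# decomposition, (flat_index // strides) % shape; e.g. for shape (2, 3):
flat, shape, strides = torch.tensor([5]), torch.tensor([2, 3]), torch.tensor([3, 1])
print((flat.unsqueeze(1) // strides.unsqueeze(0)) % shape)  # tensor([[1, 2]])
```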
"ReduceMinAlongDimUnsignedInt_basic", - "ScalarImplicitFloatModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", "SplitDimDynamicModule_basic", "SplitDimStaticModule_basic", "SqrtIntModule_basic", - "SubFloatModule_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", - "TensorToFloatZeroRank_basic", - "TensorToFloat_basic", "ThresholdBackward2dMixedModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dDynamicFactor_basic", @@ -502,32 +494,18 @@ "AdaptiveMaxPool1dStatic_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - "ElementwiseExpm1IntModule_basic", - "ElementwiseExpm1Module_basic", - "IndexPutImpl1DFloatAccumulateModule_basic", - "IndexPutImpl1DFloatNonAccumulateModule_basic", - "IndexPutImpl1DIntAccumulateModule_basic", - "IndexPutImpl1DIntNonAccumulateModule_basic", - "IndexPutImpl2DFloatNonAccumulateModule_basic", - "IndexPutImpl2DImplicitModule_basic", - "IndexPutImpl2DIndexModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "IndexPutImpl3DFloatNonAccumulateModule_basic", - "IndexPutImplIndexWithNoneModule_basic", "IsInfiniteModule_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", - "OneHotModule_basic", - # RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSignbitModule_basic", "ElementwiseCopysignModule_basic", + "BernoulliFloatModule_basic", + "BernoulliTensorModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -539,9 +517,6 @@ "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", - # torch export: RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -641,6 +616,7 @@ "AtenMmQMixedSigni8_basic", "AtenMmQint8_basic", "AtenMmQuint8_basic", + "AtenNonzero1DDynamicModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenTopKModule_basic", @@ -759,6 +735,7 @@ "LenStrModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic", @@ -856,8 +833,6 @@ "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", "SliceOutOfLowerBoundEndIndexModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", "SortTensorDescending_basic", "SortTensorInteger_basic", "SortTensorNegativeDimension_basic", @@ -922,8 +897,6 @@ "AtenItemIntOpModule_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - "ElementwiseExpm1IntModule_basic", - "ElementwiseExpm1Module_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "IscloseStaticModuleTrue_basic", @@ -932,7 +905,6 @@ "MeshgridIndexingXY_basic", "Meshgrid_basic", "MulIntModule_basic", - "OneHotModule_basic", "ReduceFrobeniusNormComplexModule_basic", "ScalarImplicitIntModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic", @@ 
@@ -950,11 +922,9 @@
     "UpSampleNearest2dStaticFactor_basic",
     "UpSampleNearest2dStaticSize_basic",
     "UpSampleNearest2d_basic",
-    # RuntimeError: cannot mutate tensors with frozen storage
-    "ElementwiseRreluTrainModule_basic",
-    "ElementwiseRreluTrainStaticModule_basic",
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "BernoulliFloatModule_basic",
+    "UniformModule_basic",
+    "UniformStaticShapeModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -976,9 +946,8 @@
     "Aten_TrilinearModuleSumdims_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
-    # torch export: RuntimeError: cannot mutate tensors with frozen storage
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "CrossEntropyLossModule_basic",
+    "CrossEntropyLossNoReductionModule_basic",
 }
 
 STABLEHLO_PASS_SET = {
@@ -1216,6 +1185,8 @@
     "ElementwiseRsqrtModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSinModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtModule_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
@@ -1705,6 +1676,8 @@
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
+    "CrossEntropyLossModule_basic",
+    "CrossEntropyLossNoReductionModule_basic",
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
@@ -1713,6 +1686,9 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "Unfold_Module_Rank_4",
+    "Unfold_Module_Rank_Zero_basic",
+    "Unfold_Module_basic",
     "ElementwiseErfIntModule_basic",
     "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseSigmoidIntModule_basic",
@@ -2275,6 +2251,7 @@
     "MatmulStaticBroadcast_basic",
     "MaxPool2dEmptyStrideStaticModule_basic",
     "MaxPool2dStaticCeilModeTrueModule_basic",
+    "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
     "MaxPool2dStaticModule_basic",
     "MeanModule_basic",
     "MmDagModule_basic",
@@ -2731,6 +2708,7 @@
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
+    "Conv1dGroupModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
@@ -2886,6 +2864,7 @@
     "Conv1dModule_basic",
     "Conv1dWithSamePaddingModule_basic",
     "Conv1dWithValidPaddingModule_basic",
+    "Conv1dGroupModule_basic",
     "Conv2dBiasNoPaddingModule_basic",
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
@@ -2949,6 +2928,8 @@
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
     "ElementwiseCreateComplexModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
@@ -3010,7 +2991,6 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
-    "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -3380,6 +3360,13 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
 }
 
+if torch_version_for_comparison() > version.parse("2.5.1"):
+    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
+        # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible
+        # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3
+        "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
+    }
+
 if torch_version_for_comparison() < version.parse("2.4.0.dev"):
     STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
         "AtenIntMM_basic",
@@ -3427,6 +3414,8 @@
 }
 
 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "UniformModule_basic",
+    "UniformStaticShapeModule_basic",
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "IsInfiniteModule_basic",
@@ -3438,19 +3427,13 @@
     "ElementwiseSignbitModule_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "MaxPool3dEmptyStrideStaticModule_basic",
     "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",
     "MaxPool3dModule_basic",
     "MaxPool3dStaticModule_basic",
     "ViewDtypeStaticModule_basic",
-    "Unfold_Module_Dynamic_basic",
-    "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_Size_Zero_basic",
-    "Unfold_Module_Rank_Zero_basic",
-    "Unfold_Module_basic",
     "ArangeZeroElementOutputModule_basic",
     "NumpyTRank0Module_basic",
     "Permute0RankModule_basic",
@@ -3582,6 +3565,7 @@
     "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
     "Conv1dWithSamePaddingModule_basic",
     "Conv1dWithValidPaddingModule_basic",
+    "Conv1dGroupModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
@@ -3651,6 +3635,8 @@
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
@@ -3871,17 +3857,10 @@
     "AdaptiveAvgPool2dDynamic_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseRreluTrainModule_basic",
-    "ElementwiseRreluTrainStaticModule_basic",
-    "IndexPutImpl1DFloatNonAccumulateModule_basic",
-    "IndexPutImpl1DIntNonAccumulateModule_basic",
-    "IndexPutImpl2DFloatNonAccumulateModule_basic",
-    "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IouOfModule_basic",
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
-    "OneHotModule_basic",
     "ReduceFrobeniusNormKeepDimModule_basic",
     "ReduceFrobeniusNormModule_basic",
     "ScaledDotProductAttentionBoolMaskModule_basic",
@@ -4173,6 +4152,7 @@
     "Conv1dWithSamePaddingModule_basic",
     "Conv1dWithValidPaddingModule_basic",
     "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
+    "Conv1dGroupModule_basic",
     "Conv2dBiasNoPaddingModule_basic",
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
@@ -4344,6 +4324,8 @@
     "ElementwiseSinIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtIntModule_basic",
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseTanIntModule_basic",
@@ -4937,3 +4919,10 @@
     "_LogSoftmaxModule_basic",
     "_SoftmaxModule_basic",
 }
+
+if torch_version_for_comparison() > version.parse("2.5.1"):
+    ONNX_TOSA_XFAIL_SET = ONNX_TOSA_XFAIL_SET | {
+        # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible
+        # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 012833b64c8a..98dfb14627fc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -222,6 +222,9 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇special_expm1〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇isfinite〡shape(self: List[int]) -> List[int]: return self @@ -2711,6 +2714,11 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool @@ -3449,10 +3457,9 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype - assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 698fec575749..a45653f04cf7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -451,6 +451,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::special_expm1 : (Tensor) -> (Tensor)") emit_with_mutating_variants( "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True ) @@ -1208,6 +1209,7 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") + # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 5e3aa3bc02f6..927bfe85df8a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 5e3aa3bc02f6..927bfe85df8a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -87,6 +87,29 @@ def BmmFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
 
 
+class BmmFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float16, True),
+            ([-1, -1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.bmm(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: BmmFloat16Module())
+def BmmFloat16Module_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(3, 4, 5).to(torch.float16), tu.rand(3, 5, 4).to(torch.float16)
+    )
+
+
 class BmmIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -6407,3 +6430,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils):
     module.forward(
         tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
     )
+
+
+# ==============================================================================
+
+
+class AtenNonzero1DDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.bool, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.nonzero(x)
+
+
+@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule())
+def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
index 7a45dd7fc0ce..663c4b6a746b 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -1199,6 +1199,33 @@ def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
+class Conv1dGroupModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.conv1d(
+            inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=2
+        )
+
+
+@register_test_case(module_factory=lambda: Conv1dGroupModule())
+def Conv1dGroupModule_basic(module, tu: TestUtils):
+    inputVec = tu.rand(2, 4, 6)
+    weight = torch.randn(8, 2, 3)
+    bias = torch.randn(8)
+    module.forward(inputVec, weight, bias)
+
+
 class Conv2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
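The new Conv1dGroupModule exercises grouped 1-D convolution: with groups=2, the 4 input channels are split into two groups of 2, so the weight's second dim is in_channels/groups. An eager-mode equivalent of the test's call (illustrative):

```python
import torch

x = torch.randn(2, 4, 6)   # (batch, in_channels=4, length=6)
w = torch.randn(8, 2, 3)   # (out_channels=8, in_channels/groups=2, kernel=3)
b = torch.randn(8)
y = torch.ops.aten.conv1d(x, w, b, [1], [0], [1], 2)  # stride, pad, dilation, groups
print(y.shape)             # torch.Size([2, 8, 4]) -- L_out = 6 - 3 + 1
```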
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 2e59db727341..7a7251555cd7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1230,7 +1230,6 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-
 class ElementwiseCeluStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -5125,7 +5124,7 @@ def __init__(self):
         ]
     )
     def forward(self, a):
-        return torch.special.expm1(a)
+        return torch.expm1(a)
 
 
 @register_test_case(module_factory=lambda: ElementwiseExpm1Module())
@@ -5148,7 +5147,7 @@ def __init__(self):
         ]
     )
     def forward(self, a):
-        return torch.special.expm1(a)
+        return torch.expm1(a)
 
 
 @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule())
@@ -5159,6 +5158,52 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseSpecialExpm1Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module())
+def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class ElementwiseSpecialExpm1IntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule())
+def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseRad2DegModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 84e0e2eb9cf5..e2eaa4cfd0fe 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0))
 
 
+class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mp2d = torch.nn.MaxPool2d(
+            kernel_size=6,
+            stride=6,
+            padding=3,
+            dilation=1,
+            ceil_mode=True,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 6, 20, 10], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.mp2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule()
+)
+def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0))
+
+
 # ==============================================================================
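The new pooling test pins down a ceil_mode corner case: with kernel 6, stride 6, padding 3 on a 20x10 input, the naive ceil formula gives a 5x3 output, but PyTorch drops the last window along H because it would start entirely inside the right padding, yielding 4x3. Eager check (illustrative):

```python
import torch

mp = torch.nn.MaxPool2d(kernel_size=6, stride=6, padding=3, ceil_mode=True)
print(mp(torch.rand(2, 6, 20, 10)).shape)  # torch.Size([2, 6, 4, 3]), not [2, 6, 5, 3]
```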
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index a8820f59c373..d1ddc42b39b1 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1752,7 +1752,7 @@ def forward(self, x):
         return x.unfold(0, 0, 1)
 
 
-@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero())
+@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero())
 def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils):
     module.forward(tu.rand())
 
diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index ad873201dbba..0439f8244a0b 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-0d5247caf3ffd618d31cf4cf880c47b7dbd323a7
+3f159d635772fa2a8fd352d96b95100d885f8169
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index c18413eacec9..7ab5a78d074f 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.6.0.dev20241107
+torch==2.6.0.dev20241216
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 5a5fb83d5fc0..30b85e63ab0f 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1580,12 +1580,14 @@ func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor
 
 // -----
 
-// CHECK-LABEL: func.func @test_nonzero
-  func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-    // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
-    %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64>
-    return %0 : !torch.vtensor<[3,4,5],si64>
-  }
+func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[ONE:.*]] = torch.constant.int 1
+  // CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64>
+  // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
+  %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64>
+  return %0 : !torch.vtensor<[1,?],si64>
+}
 
 // -----
 
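The updated test reflects the layout mismatch the lowering must bridge: torch.aten.nonzero returns indices as [num_nonzero, rank], while onnx.NonZero expects [rank, num_nonzero], hence the transpose. In eager terms (illustrative):

```python
import torch

t = torch.tensor([0.0, 1.0, 2.0])
print(torch.nonzero(t))      # tensor([[1], [2]]) -- shape [num_nonzero, rank]
print(torch.nonzero(t).t())  # shape [rank, num_nonzero], the ONNX layout
```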
@@ -2055,22 +2057,30 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
   // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
   // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1],si64>
-  // CHECK: %[[VAL_26:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK: %[[VAL_28:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_30:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[VAL_32:.*]] = torch.constant.none
-  // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class"
-  // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64>
+  // CHECK: %[[VAL_24:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_25:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64>
+  // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
+  // CHECK:   %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK:   torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
+  // CHECK: } else {
+  // CHECK:   %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
+  // CHECK:   torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
+  // CHECK: }
+  // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+  // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
+  // CHECK: %[[VAL_35:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_37:.*]] = torch.constant.none
+  // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
+  // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
+  // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
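The new CHECK lines encode a behavior change: instead of asserting that torchvision.nms returned at most max_output_boxes_per_class results, the lowering now slices the dynamically sized result down to that limit. A Python sketch of the same logic (illustrative; uses the real torchvision.ops.nms API):

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [0.0, 0.0, 10.0, 10.0],
                      [20.0, 20.0, 30.0, 30.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.5)  # variable-length result
max_boxes = 1
if keep.size(0) > max_boxes:                  # slice instead of asserting
    keep = keep[:max_boxes]
print(keep)                                   # tensor([0])
```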
!torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" - // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],si64> - // CHECK: %[[VAL_26:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> - // CHECK: %[[VAL_28:.*]] = torch.constant.int 0 - // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_30:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_32:.*]] = torch.constant.none - // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list - // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class" - // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> - // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64> - // CHECK: } + // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64> + // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64> + // CHECK: } else { + // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> + // CHECK: } + // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> 
!torch.int + // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_37:.*]] = torch.constant.none + // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list + // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> return %0 : !torch.vtensor<[1,3],si64> } diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 9e504c082a8c..a3d52166385a 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2943,3 +2943,53 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s } // ----- + +// CHECK-LABEL: func.func @torch.aten.unfold$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x4xi32>) -> tensor<6x4x1xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<6x4x2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x24x1xf32>) -> tensor<6x4xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_17:.*]] 
= "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[3,4,2],f32> +// CHECK: } +func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unfold$rank_zero( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32> +// CHECK: } +func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- diff --git a/test/python/fx_importer/v2.3/auto_functionalized.py b/test/python/fx_importer/v2.3/auto_functionalized.py index ab7401dcc2fb..7fb0eeb3b67f 100644 --- a/test/python/fx_importer/v2.3/auto_functionalized.py +++ b/test/python/fx_importer/v2.3/auto_functionalized.py @@ -59,8 +59,9 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> - # CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]] + # The Torch 2.6 expects the IR to be same as the below one, while the torch versions < 2.6 does not, hence this check is kept as a "COM". + # COM: torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> () + # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} print(m) m.operation.verify() @@ -86,7 +87,8 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) - # CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0 + # The Torch 2.6 expects the IR to be same as the below one, while the torch versions < 2.6 does not, hence this check is kept as a "COM". 
+    # COM: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
+    # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}}
     print(m)
     m.operation.verify()
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index 8c8d45bea8a9..be1615525984 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.20.0.dev20241107
+torchvision==0.22.0.dev20241216