From d0a3cb45971634e35cb421e319ed30b038ce95ba Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 10 Dec 2024 12:37:40 +0530 Subject: [PATCH 01/17] build: manually update PyTorch version (#3896) This commit sets the PyTorch and TorchVision version to nightly release 2024-12-01. This commit also updates the test checks in `test/python/fx_importer/v2.3/auto_functionalized.py`. Failing tests are tracked through https://github.com/llvm/torch-mlir/issues/3796. --------- Signed-off-by: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 29 ++++++------------- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- .../fx_importer/v2.3/auto_functionalized.py | 10 ++++--- torchvision-requirements.txt | 2 +- 5 files changed, 18 insertions(+), 27 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7430ad89c2c2..9f832cb9e033 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -464,8 +464,6 @@ "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ScalarImplicitFloatModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", "SplitDimDynamicModule_basic", "SplitDimStaticModule_basic", "SqrtIntModule_basic", @@ -504,30 +502,21 @@ "CrossEntropyLossNoReductionModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", - "IndexPutImpl1DFloatAccumulateModule_basic", - "IndexPutImpl1DFloatNonAccumulateModule_basic", - "IndexPutImpl1DIntAccumulateModule_basic", - "IndexPutImpl1DIntNonAccumulateModule_basic", - "IndexPutImpl2DFloatNonAccumulateModule_basic", - "IndexPutImpl2DImplicitModule_basic", - "IndexPutImpl2DIndexModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "IndexPutImpl3DFloatNonAccumulateModule_basic", - "IndexPutImplIndexWithNoneModule_basic", "IsInfiniteModule_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", - "OneHotModule_basic", # RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSignbitModule_basic", "ElementwiseCopysignModule_basic", + "BernoulliFloatModule_basic", + "BernoulliTensorModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -856,8 +845,6 @@ "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", "SliceOutOfLowerBoundEndIndexModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", "SortTensorDescending_basic", "SortTensorInteger_basic", "SortTensorNegativeDimension_basic", @@ -932,7 +919,6 @@ "MeshgridIndexingXY_basic", "Meshgrid_basic", "MulIntModule_basic", - "OneHotModule_basic", "ReduceFrobeniusNormComplexModule_basic", "ScalarImplicitIntModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic", @@ -951,10 +937,11 @@ "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", # RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "BernoulliFloatModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -979,6 +966,8 @@ # torch export: RuntimeError: cannot mutate tensors with frozen storage "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", } STABLEHLO_PASS_SET = { diff --git a/pytorch-hash.txt b/pytorch-hash.txt index ad873201dbba..ae415d496d6d 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -0d5247caf3ffd618d31cf4cf880c47b7dbd323a7 +798d5b7ddd08899fb62672d56044dbf1f63a4d17 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index c18413eacec9..83ecc622c492 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch/ --pre -torch==2.6.0.dev20241107 +torch==2.6.0.dev20241201 diff --git a/test/python/fx_importer/v2.3/auto_functionalized.py b/test/python/fx_importer/v2.3/auto_functionalized.py index ab7401dcc2fb..7fb0eeb3b67f 100644 --- a/test/python/fx_importer/v2.3/auto_functionalized.py +++ b/test/python/fx_importer/v2.3/auto_functionalized.py @@ -59,8 +59,9 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> - # CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]] + # The Torch 2.6 expects the IR to be same as the below one, while the torch versions < 2.6 does not, hence this check is kept as a "COM". + # COM: torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> () + # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} print(m) m.operation.verify() @@ -86,7 +87,8 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) - # CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0 + # The Torch 2.6 expects the IR to be same as the below one, while the torch versions < 2.6 does not, hence this check is kept as a "COM". + # COM: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} print(m) m.operation.verify() diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 8c8d45bea8a9..e0583c31e56c 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre -torchvision==0.20.0.dev20241107 +torchvision==0.20.0.dev20241201 From 49b3d255774f55fcf2a92527b3163d7845e905d0 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 10 Dec 2024 22:22:51 +0100 Subject: [PATCH 02/17] Pin and update actions (#3907) This pins and updates most actions. The PR is limited to those actions that seem actively maintained and updated. The actions left unpined should be reevaluated and eventually replaced with other actions. The rational for pinning actions is to follow the suggestions by OpenSSF Scorecard, see https://github.com/ossf/scorecard/blob/main/docs/checks.md#pinned-dependencies. --- .github/actions/setup-build/action.yml | 4 +-- .github/workflows/RollPyTorch.yml | 8 +++--- .github/workflows/bazelBuildAndTest.yml | 6 ++--- .github/workflows/buildRelease.yml | 26 ++++++++++---------- .github/workflows/gh-pages-releases.yml | 2 +- .github/workflows/merge-rollpytorch.yml | 2 +- .github/workflows/oneshotSnapshotPackage.yml | 2 +- .github/workflows/pre-commit-all.yml | 6 ++--- .github/workflows/pre-commit.yml | 6 ++--- .github/workflows/releaseSnapshotPackage.yml | 6 ++--- 10 files changed, 34 insertions(+), 34 deletions(-) diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index a21c9a1d7296..7ed50e866492 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -27,7 +27,7 @@ runs: steps: - name: Set up Python if: ${{ runner.arch == 'X64' }} - uses: actions/setup-python@v4 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.11' @@ -74,7 +74,7 @@ runs: - name: Enable ccache if: ${{ inputs.cache-enabled == 'true' }} - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ github.workspace }}/.ccache key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 3c8b95a3181a..8c571893e145 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -22,7 +22,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -95,7 +95,7 @@ jobs: - name: Post issue comment on build failure if: failure() - uses: peter-evans/create-or-update-comment@v2 + uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0 with: issue-number: 1690 body: | @@ -111,7 +111,7 @@ jobs: - name: Update PyTorch Build Cache (if running on main branch) if: github.ref_name == 'main' id: cache-pytorch - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} @@ -127,7 +127,7 @@ jobs: git pull origin main - name: Create pull request - uses: peter-evans/create-pull-request@v5.0.1 + uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f # v7.0.5 with: author: Roll PyTorch Action branch: rollpytorch diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 23f2addbe5af..747a8424d7c0 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -32,7 +32,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checkout torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' @@ -40,7 +40,7 @@ jobs: # restore to avoid the cache going stale over time # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache - name: Setup cache for bazel - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ~/.cache/bazel key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} @@ -102,7 +102,7 @@ jobs: - name: Send mail if: failure() - uses: dawidd6/action-send-mail@v3 + uses: dawidd6/action-send-mail@2cea9617b09d79a095af21254fbcb7ae95903dde # v3.12.0 with: server_address: ${{ secrets.SMTP_SERVER }} server_port: ${{ secrets.SMTP_PORT }} diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index e84aabb4b388..7b09cf050563 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -28,7 +28,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' fetch-depth: 0 @@ -59,7 +59,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -75,7 +75,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -96,7 +96,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' fetch-depth: 0 @@ -127,7 +127,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -143,7 +143,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -156,7 +156,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -187,7 +187,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -203,7 +203,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -216,7 +216,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -250,7 +250,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -267,7 +267,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: wheels path: dist @@ -285,7 +285,7 @@ jobs: steps: - name: Invoke Publish Releases Page - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Publish releases page token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index a0eb45257b11..112d4b4a8ee0 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -20,7 +20,7 @@ jobs: # existing lock files. sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: Run scrape releases script diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 58a91fd1d409..26c6eba46571 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -18,7 +18,7 @@ jobs: steps: # Fetch the repo first so that the gh command knows where to look for the PR - name: Fetch Repo - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index ec1878606624..f3ab4be178ed 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -18,7 +18,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/pre-commit-all.yml b/.github/workflows/pre-commit-all.yml index e17d4ebdbb43..b370a2966968 100644 --- a/.github/workflows/pre-commit-all.yml +++ b/.github/workflows/pre-commit-all.yml @@ -8,8 +8,8 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 with: extra_args: --color=always --all-files diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 29733c2e5d45..fc1b6d2ab392 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -7,11 +7,11 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: # requites to grab the history of the PR fetch-depth: 0 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 with: extra_args: --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index 8a0ec914440f..812f5ce488a3 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -21,7 +21,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -58,14 +58,14 @@ jobs: prerelease: false - name: "Invoke workflow :: Build and Test" - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Build and Test token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} ref: "${{ env.tag_name }}" - name: "Invoke workflow :: Release Build" - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Release Build token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} From 59b3614e3c80458d43698a3a9842317f80701064 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 10 Dec 2024 22:32:53 +0100 Subject: [PATCH 03/17] Replace `ubuntu-latest` with specific version (#3906) While `ubuntu-latest` uses Ubuntu 22.04 for now, thils will change soon (rollout already started), see https://github.com/actions/runner-images/issues/10636. The version can be updated from 22.04 to 24.04 in a follow up. --- .github/workflows/bazelBuildAndTest.yml | 2 +- .github/workflows/buildRelease.yml | 2 +- .github/workflows/gh-pages-releases.yml | 2 +- .github/workflows/merge-rollpytorch.yml | 2 +- .github/workflows/pre-commit-all.yml | 2 +- .github/workflows/pre-commit.yml | 2 +- .github/workflows/releaseSnapshotPackage.yml | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 747a8424d7c0..4eeef0b9bb5e 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -22,7 +22,7 @@ concurrency: jobs: ubuntu-build: name: ubuntu-x86_64 - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Prepare workspace diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 7b09cf050563..a304672b474f 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -273,7 +273,7 @@ jobs: path: dist publish_releases: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 needs: - build_linux - build_linux_arm64 diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index 112d4b4a8ee0..e87630edb28c 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -8,7 +8,7 @@ on: jobs: scrape_and_publish_releases: name: "Scrape and publish releases" - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 26c6eba46571..e335f1fdfd7d 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -9,7 +9,7 @@ on: jobs: merge-pr: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 if: | github.repository == 'llvm/torch-mlir' && github.event.workflow_run.actor.login == 'stellaraccident' && diff --git a/.github/workflows/pre-commit-all.yml b/.github/workflows/pre-commit-all.yml index b370a2966968..2c0d61e92747 100644 --- a/.github/workflows/pre-commit-all.yml +++ b/.github/workflows/pre-commit-all.yml @@ -6,7 +6,7 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index fc1b6d2ab392..6a848fe8674f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -5,7 +5,7 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index 812f5ce488a3..b6822b3701d6 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -9,7 +9,7 @@ on: jobs: release_snapshot_package: name: "Tag snapshot release" - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' steps: From 31b912e83e2ccf714eef79229341a4b8c0c2bb3d Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 10 Dec 2024 22:42:24 +0100 Subject: [PATCH 04/17] Replace unmaintained `create-release` action (#3905) This replaces the `actions/create-release` with `ncipollo/release-action` as the former is unmaintained. --- .github/workflows/oneshotSnapshotPackage.yml | 9 ++++----- .github/workflows/releaseSnapshotPackage.yml | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index f3ab4be178ed..92d732cea3a6 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -43,16 +43,15 @@ jobs: - name: Create Release id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 with: - tag_name: ${{ env.tag_name }} - release_name: torch-mlir snapshot ${{ env.tag_name }} + tag: ${{ env.tag_name }} + name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. draft: true prerelease: false + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: "Invoke workflow :: Build and Test" uses: benc-uk/workflow-dispatch@v1 diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index b6822b3701d6..7b575764ac8e 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -46,16 +46,15 @@ jobs: - name: Create Release id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 with: - tag_name: ${{ env.tag_name }} - release_name: torch-mlir snapshot ${{ env.tag_name }} + tag: ${{ env.tag_name }} + name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. draft: true prerelease: false + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: "Invoke workflow :: Build and Test" uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 From 5a5cc6b34117e9956a4c7438afa8d83ae0bb9ee6 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 11 Dec 2024 10:36:51 +0530 Subject: [PATCH 05/17] [MLIR][TORCH] Add aten.special.expm1 op lowering (#3878) This commit adds the support for torch.aten.special.expm1 op by decomposing it into torch.aten.expm1 op. --------- Signed-off-by: Vivek Khandelwal --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 23 +++++++++ .../Transforms/AbstractInterpLibrary.cpp | 9 ++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 14 ++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 12 +++-- .../build_tools/abstract_interp_lib_gen.py | 8 +++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 50 ++++++++++++++++++- 8 files changed, 112 insertions(+), 6 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f951de9af795..556b0aa76e93 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4610,6 +4610,29 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [ }]; } +def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSignOp : Torch_Op<"aten.sign", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index edcc81a2847f..fb0aaa7201b8 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6495,6 +6495,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -11589,6 +11593,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 919c4727b1f9..063dca041901 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -11177,6 +11177,19 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenSpecialExpm1Op + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSpecialExpm1Op op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf()); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -11462,6 +11475,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenFMaxMinOp>(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index f868c4c1800a..25635d2c5c46 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -569,6 +569,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 9f832cb9e033..d2c6e6c9a762 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -500,8 +500,6 @@ "AdaptiveMaxPool1dStatic_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - "ElementwiseExpm1IntModule_basic", - "ElementwiseExpm1Module_basic", "IsInfiniteModule_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", @@ -909,8 +907,6 @@ "AtenItemIntOpModule_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - "ElementwiseExpm1IntModule_basic", - "ElementwiseExpm1Module_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "IscloseStaticModuleTrue_basic", @@ -1209,6 +1205,8 @@ "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", "ElementwiseSqrtModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", @@ -2951,6 +2949,8 @@ "ElementwiseEluNonDefaultModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", "ElementwiseFmodTensor_Int_basic", "ElementwiseCreateComplexModule_basic", "ElementwiseMulTensorComplexModule_basic", @@ -3662,6 +3662,8 @@ "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseWhereScalarOtherStaticModule_basic", @@ -4355,6 +4357,8 @@ "ElementwiseSinIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", "ElementwiseSqrtIntModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseTanIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 331aa476910e..2a980bf534fd 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -222,6 +222,9 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇special_expm1〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇isfinite〡shape(self: List[int]) -> List[int]: return self @@ -2717,6 +2720,11 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 8a0417a85189..4c2de094e109 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -452,6 +452,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::special_expm1 : (Tensor) -> (Tensor)") emit_with_mutating_variants( "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 38fccc06b393..b1745fa5b85a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -5207,7 +5207,7 @@ def __init__(self): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1Module()) @@ -5230,7 +5230,7 @@ def __init__(self): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) @@ -5241,6 +5241,52 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSpecialExpm1Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.special.expm1(a) + + +@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module()) +def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseSpecialExpm1IntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) + def forward(self, a): + return torch.special.expm1(a) + + +@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule()) +def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseRad2DegModule(torch.nn.Module): def __init__(self): super().__init__() From f03a5762c3598da39ac44f1edbc7aa4579ef3262 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Thu, 12 Dec 2024 04:08:27 -0500 Subject: [PATCH 06/17] [TorchToTosa] Refactoring to separate construction of legal/illegal ops and conversion patterns. (#3759) This PR refactors TorchToTosa to separate the construction of legal/illegal ops and conversion patterns in their own functions: 1. populateTorchToTosaConversionLegalOps -- populate any ops that are legal after the conversion pass 2. populateTorchToTosaConversionIllegalOps -- populate any ops that are illegal after the conversion pass 3. populateTorchToTosaConversionPatterns -- populate the ops conversion patterns Currently the (il)legality of the ops that are (il)legal after the conversion pass runs is embedded within the conversion pattern. Our end goal is to write a new pass pipeline that converts `torch` ops to a mix of `tosa`, `linalg`, `tensor`, etc dialect ops. The reason we want to also emit `tosa` ops (instead of using the existing `TorchToLinalg` to emit `linalg`+`tensor`+...) is because some operations like `conv2d` encodes the padding behavior in the op in `tosa` unlike the `linalg` version -- this helps in lowering the `tosa.conv2d` to a custom implementation that does padding on the fly. To implement this new pipeline we need to be able to separate out the illegal `tosa` ops from the conversion pattern itself. Otherwise we will hit an issue for ops like `AtenMaxDimOp` which can be lowered to both `tosa` and `linalg + others` dialects. Not all `AtenMaxDimOp` can be lowered successfully to `tosa` as the implementation uses `tosa.reshape` which cannot handle multiple dynamic dimensions but the `TorchToLinalg` lowering can handle it. In the current behavior the pipeline will stop as soon as the existing `TorchToTosa` conversion runs as `AtenMaxDimOp` will be marked as an illegal op. Essentially we want to be able to control what the legality of the ops should be independent of the conversion pattern. This is also inline with the conversion patterns in the llvm-mlir repo such as https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718 "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY." --- .../Conversion/TorchToTosa/TorchToTosa.h | 15 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 456 +++++++++--------- 2 files changed, 249 insertions(+), 222 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db1..221745b1c26e 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -12,12 +12,25 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + #include namespace mlir { namespace torch { + +/// Collect a set of legal/illegal ops for converting Torch operations to Tosa +/// dialect. +void populateTorchToTosaConversionLegalOps(ConversionTarget &target); + +/// Collect a set of patterns to convert Torch operations to Tosa dialect + +/// return the set of illegalOps +std::set +populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, + RewritePatternSet &patterns); + std::unique_ptr> createConvertTorchToTosaPass(); -} +} // namespace torch } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 9572723fdd29..1c05ae49e18b 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8277,342 +8277,356 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { ConversionTarget target(*context); target.addLegalDialect(); + target.addIllegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - // The following ops are never the primary reason why lowering fails. - // The backend contract only allows functions to return tensors thus there - // is always another op using them. - // When we have a chain of torch.constant.int followed by a unsupported - // torch op, we want the pass to mention the unsupported torch op - // in the error message. - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addIllegalDialect(); + populateTorchToTosaConversionLegalOps(target); RewritePatternSet patterns(context); + auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps( + typeConverter, patterns); + + for (auto op : illegalOps) { + target.addIllegalOp(OperationName(op, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors thus there + // is always another op using them. + // When we have a chain of torch.constant.int followed by a unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); +} + +std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + + MLIRContext *context = patterns.getContext(); + std::set illegalOps; + #define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) #undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN #define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) - INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) - INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) - INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) - INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) - INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) - INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) - INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) + INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) + INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) + INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) + INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) #undef INSERT_UNARY_PATTERN #define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) - INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) - INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) - INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) - INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) - INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, - tosa::LogicalLeftShiftOp) - INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, - tosa::ArithmeticRightShiftOp) + INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) + INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) + INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) + INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) + INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp) + INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, + tosa::ArithmeticRightShiftOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) #undef INSERT_BINARY_COMPARE_PATTERN #define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); #undef INSERT_BINARY_MUL_PATTERN #define INSERT_BINARY_DIV_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); #undef INSERT_BINARY_DIV_PATTERN #define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); #undef INSERT_REMAINDER_FMOD_OP_PATTERN #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, - mlir::tosa::convertReduceMeanOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, - mlir::tosa::convertReduceSumOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, - mlir::tosa::convertLinalgVectorNormOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, + mlir::tosa::convertReduceMeanOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, + mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, - mlir::tosa::convertReduceAllOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, - mlir::tosa::convertReduceProdOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, + mlir::tosa::convertReduceAllOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ONEDIM_REDUCTION_OP_PATTERN #define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>( \ typeConverter, context); - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, - mlir::tosa::convertReduceAllOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, - mlir::tosa::convertReduceSumOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, - mlir::tosa::convertReduceMaxOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, - mlir::tosa::convertReduceMinOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, - mlir::tosa::convertReduceProdOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, + mlir::tosa::convertReduceProdOp) #undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN #define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); #undef INSERT_INDICES_REDUCTION_OP_PATTERN #define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) #undef INSERT_SQUEEZE_OP_PATTERN #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); #undef INSERT_MATMUL_ATEMOP_PATTERN #define INSERT_MM_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MM_ATENOP_PATTERN(AtenMmOp); - INSERT_MM_ATENOP_PATTERN(AtenBmmOp); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); #undef INSERT_MM_ATEMOP_PATTERN #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); #undef INSERT_LINEAR_ATEMOP_PATTERN #define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, - tosa::AvgPool2dOp); + INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, + tosa::AvgPool2dOp); #undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenMaxPool1dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool2dOp::getOperationName()); + patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + illegalOps.insert(AtenAvgPool1dOp::getOperationName()); + patterns.add(typeConverter, context); #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); - INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); - INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_FILL_PATTERN(AtenFill_ScalarOp); - INSERT_FILL_PATTERN(AtenFillScalarOp); - INSERT_FILL_PATTERN(AtenFillTensorOp); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); #undef INSERT_FILL_PATTERN #define INSERT_MASKED_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); #undef INSERT_MASKED_FILL_PATTERN #define INSERT_POW_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); - INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); - INSERT_POW_OP_PATTERN(AtenPowScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); + INSERT_POW_OP_PATTERN(AtenPowScalarOp); #undef INSERT_POW_OP_PATTERN +#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); +#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN + #define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, \ context); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); #undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN -#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); -#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN - #define INSERT_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); - INSERT_ATENOP_PATTERN(AtenReluOp); - INSERT_ATENOP_PATTERN(AtenLeakyReluOp); - INSERT_ATENOP_PATTERN(AtenArgmaxOp); - INSERT_ATENOP_PATTERN(AtenRsubScalarOp); - INSERT_ATENOP_PATTERN(AtenConvolutionOp); - INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); - INSERT_ATENOP_PATTERN(AtenReshapeOp); - INSERT_ATENOP_PATTERN(AtenBatchNormOp); - INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); - INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); - INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); - INSERT_ATENOP_PATTERN(AtenPermuteOp); - INSERT_ATENOP_PATTERN(AtenLog2Op); - INSERT_ATENOP_PATTERN(AtenThresholdOp); - INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); - INSERT_ATENOP_PATTERN(AtenContiguousOp); - INSERT_ATENOP_PATTERN(AtenDropoutOp); - INSERT_ATENOP_PATTERN(AtenViewOp); - INSERT_ATENOP_PATTERN(AtenGeluOp); - INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); - INSERT_ATENOP_PATTERN(AtenEmbeddingOp); - INSERT_ATENOP_PATTERN(AtenTransposeIntOp); - INSERT_ATENOP_PATTERN(AtenSliceTensorOp); - INSERT_ATENOP_PATTERN(AtenBroadcastToOp); - INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenAbsOp); - INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenClampOp); - INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); - INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenCopyOp); - INSERT_ATENOP_PATTERN(AtenToDtypeOp); - INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); - INSERT_ATENOP_PATTERN(AtenCatOp); - INSERT_ATENOP_PATTERN(AtenSqrtOp); - INSERT_ATENOP_PATTERN(AtenIscloseOp); - INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); - INSERT_ATENOP_PATTERN(AtenTrilOp); - INSERT_ATENOP_PATTERN(AtenDiagonalOp); - INSERT_ATENOP_PATTERN(AtenIndexSelectOp); - INSERT_ATENOP_PATTERN(AtenFlipOp); - INSERT_ATENOP_PATTERN(AtenRoundOp); - INSERT_ATENOP_PATTERN(AtenScatterSrcOp); - INSERT_ATENOP_PATTERN(AtenSliceScatterOp); - INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); - INSERT_ATENOP_PATTERN(AtenUniformOp); - INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); - INSERT_ATENOP_PATTERN(AtenAsStridedOp); - INSERT_ATENOP_PATTERN(AtenClampTensorOp); - INSERT_ATENOP_PATTERN(PrimsCollapseOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); - INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); - INSERT_ATENOP_PATTERN(PrimsSplitDimOp); - INSERT_ATENOP_PATTERN(AtenOuterOp); - INSERT_ATENOP_PATTERN(AtenLogitOp); - INSERT_ATENOP_PATTERN(AtenLog1pOp); - INSERT_ATENOP_PATTERN(AtenLog10Op); - INSERT_ATENOP_PATTERN(AtenTanOp); + INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenLeakyReluOp); + INSERT_ATENOP_PATTERN(AtenArgmaxOp); + INSERT_ATENOP_PATTERN(AtenRsubScalarOp); + INSERT_ATENOP_PATTERN(AtenConvolutionOp); + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReshapeOp); + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenThresholdOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); + INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenDropoutOp); + INSERT_ATENOP_PATTERN(AtenViewOp); + INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenTransposeIntOp); + INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenAbsOp); + INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenCopyOp); + INSERT_ATENOP_PATTERN(AtenToDtypeOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); + INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); + INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); + INSERT_ATENOP_PATTERN(AtenTrilOp); + INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); + INSERT_ATENOP_PATTERN(AtenScatterSrcOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); + INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); + INSERT_ATENOP_PATTERN(AtenUniformOp); + INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); + INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); + INSERT_ATENOP_PATTERN(PrimsSplitDimOp); + INSERT_ATENOP_PATTERN(AtenOuterOp); + INSERT_ATENOP_PATTERN(AtenLogitOp); + INSERT_ATENOP_PATTERN(AtenLog1pOp); + INSERT_ATENOP_PATTERN(AtenLog10Op); + INSERT_ATENOP_PATTERN(AtenTanOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ + illegalOps.insert(AtenOp::getOperationName()); \ patterns.add>(typeConverter, context); - INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); #undef INSERT_CLONE_ATENOP_PATTERN - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); - } -}; -} // namespace + return illegalOps; +} std::unique_ptr> mlir::torch::createConvertTorchToTosaPass() { From 2c72a82e60dfbedfdccf6c4c77140bf61ec7a597 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:19:00 -0800 Subject: [PATCH 07/17] [ONNX] Fix nonzero output type difference between onnx and torch (#3916) The onnx output tensor has a shape of ((n, z)), where (n) is the number of dimensions in the input tensor and (z) is the number of non-zero elements2. This is different from PyTorch's default behavior, where the dimensions are reversed. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 41 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 14 ++++--- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 7446b7faaa08..13f555c146b4 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1093,18 +1093,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.replaceOp(binder.op, nllLoss); return success(); }); - patterns.onOp("NonZero", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) { - return failure(); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + auto rawSize = resultType.getSizes(); + SmallVector torchResultSize(rawSize.rbegin(), rawSize.rend()); + auto torchResultType = rewriter.getType( + torchResultSize, resultType.getDtype()); + auto nonZero = rewriter.create( + binder.getLoc(), torchResultType, operand); + // The output tensor has a shape of ((n, z)), where (n) is the + // number of dimensions in the input tensor and (z) is the + // number of non-zero elements2. This is different from + // PyTorch's default behavior, where the dimensions are + // reversed. + rewriter.replaceOpWithNewOp( + binder.op, resultType, nonZero, zero, one); + return success(); + }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 5a5fb83d5fc0..7f1e63d83ccd 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1580,12 +1580,14 @@ func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor // ----- -// CHECK-LABEL: func.func @test_nonzero - func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> - %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> - return %0 : !torch.vtensor<[3,4,5],si64> - } +func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64> + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64> + %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> + return %0 : !torch.vtensor<[1,?],si64> +} // ----- From 8e0eafd022cd7555c8b58927d3238a7a89e9dbd4 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 13 Dec 2024 11:05:40 +0530 Subject: [PATCH 08/17] [MLIR][TORCH] Add support for 1-d group convolution (#3904) This commit adds the support for 1-d group convolution by transforming it into a 2-d group convolution which is already supported. This commit also refactors the unsqueeze and squeeze tensor utility. --------- Signed-off-by: Vivek Khandelwal --- include/torch-mlir/Conversion/Utils/Utils.h | 9 ++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 98 ++------------- lib/Conversion/TorchToLinalg/Linear.cpp | 72 +++++++++-- lib/Conversion/Utils/Utils.cpp | 113 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 4 + .../torch_mlir_e2e_test/test_suite/conv.py | 27 +++++ 6 files changed, 230 insertions(+), 93 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index d21dd5504dcd..264fb4966d39 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize); +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index a18c0bae01fc..b8c20bc73f65 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - Value input = adaptor.getSelf(); - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - - if (inputRank == 0) { - return rewriter.notifyMatchFailure( - op, "zero input rank should have been handled by the folder"); - } - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - - // assert dynamic squeeze dim size == 1 - if (inputType.isDynamicDim(dim)) { - Value cstDim = rewriter.create(op.getLoc(), dim); - Value dimVal = rewriter.create(op.getLoc(), input, cstDim); - Value cstOne = rewriter.create(op.getLoc(), 1); - Value cmp = rewriter.create( - op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne); - rewriter.create( - op.getLoc(), cmp, - rewriter.getStringAttr( - "Expected dynamic squeeze dim size to be statically 1")); - } - - const TypeConverter *typeConverter = getTypeConverter(); - auto resultType = - cast(typeConverter->convertType(op.getType())); - int64_t resultRank = resultType.getRank(); - // If the dim(th) dimension of operand tensor type is not statically unit, - // `aten.squeeze` will behave as an identity operation. - if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { - rewriter.replaceOpWithNewOp(op, resultType, input); - return success(); + auto squeezeTensorInfo = + squeezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(squeezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - SmallVector reassociationMap(resultRank); - bool alreadyCrossedSqueezedDim = false; - for (int i = 0; i != resultRank; i++) { - if (alreadyCrossedSqueezedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (dim != 0 && i != dim - 1) - continue; - - alreadyCrossedSqueezedDim = true; - if (dim == 0) - reassociationMap[0].push_back(1); - if (i == dim - 1) - reassociationMap[i].push_back(dim); - } - } - // Note: In case the operand tensor type is of unit rank and is statically - // shaped with unit dimension, the `reassociationMap` will be empty and the - // input will be collapsed to a 0-D tensor. - rewriter.replaceOpWithNewOp(op, resultType, input, - reassociationMap); + rewriter.replaceOp(op, squeezeTensorInfo.value()); return success(); } }; @@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - auto inputRank = - cast(adaptor.getSelf().getType()).getRank(); - dim = toPositiveDim(dim, inputRank + 1); - if (!isValidDim(dim, inputRank + 1)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector reassociationMap(inputRank); - // From the perspective of the reassociation map, the situation of - // unsqueezing before or after the last dimension is symmetrical. - // Normalize it to the "before" case. - // The 0 case is special here, since there is no last dimension to insert - // before -- we simply rely on the loop below iterating 0 times. - if (dim == inputRank && inputRank != 0) - dim = inputRank - 1; - bool alreadyCrossedExpandedDim = false; - for (int i = 0; i != inputRank; i++) { - if (alreadyCrossedExpandedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (i == dim) { - reassociationMap[i].push_back(i + 1); - alreadyCrossedExpandedDim = true; - } - } + auto unsqueezeTensorInfo = + unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(unsqueezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - auto resultType = cast( - getTypeConverter()->convertType(op->getResult(0).getType())); - rewriter.replaceOpWithNewOp( - op, resultType, adaptor.getSelf(), reassociationMap); + + rewriter.replaceOp(op, unsqueezeTensorInfo.value()); return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9ec7761704ea..4e93804b9ca5 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int dilations"); + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); + + // Adding support for 1d group convolution by converting the 1d-conv to + // 2d-conv. + // TODO: Replace this logic with the appropriate linalg op for 1-d group + // convolution once that support is added. + bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1); + if (is1DGroupConv) { + // Unsqueezing the last dim of input and weight. Also extending the + // dilation, stride, padding, and output padding lists. + auto unsqueezeInputInfo = + unsqueezeTensor(rewriter, op, input, /*dim=*/-1); + if (failed(unsqueezeInputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + input = unsqueezeInputInfo.value(); + + auto unsqueezeWeightInfo = + unsqueezeTensor(rewriter, op, weight, /*dim=*/-1); + if (failed(unsqueezeWeightInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + weight = unsqueezeWeightInfo.value(); + + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + paddingIntValues.push_back(cstZero); + outputPaddingIntValues.push_back(cstZero); + strideInts.push_back(1); + dilationInts.push_back(1); + + inRank++; + numSpatialDims++; + } + Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; @@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Checks for valid group size - int64_t numGroups; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) - return rewriter.notifyMatchFailure(op, - "only constant group size supported."); - Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); - auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); @@ -1280,13 +1315,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } + rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } if (numSpatialDims != 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); + op, "unimplemented: only 1D and 2D grouped convolution supported"); // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { @@ -1371,6 +1417,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index e3f5b6d0299a..72217e5f4afd 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -447,6 +447,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, return castIntToIndex(rewriter, loc, boundedByDimSize); } +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + ArrayRef inputShape = inputType.getShape(); + + // `input` has a reduced rank. Hence add 1. + int64_t unsqueezedRank = inputShape.size() + 1; + dim = toPositiveDim(dim, unsqueezedRank); + if (!isValidDim(dim, unsqueezedRank)) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + + SmallVector unsqueezedShape{inputShape}; + unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1); + Type unsqueezedType = + RankedTensorType::get(unsqueezedShape, inputType.getElementType()); + + SmallVector reassociationMap(inputRank); + // From the perspective of the reassociation map, the situation of + // unsqueezing before or after the last dimension is symmetrical. + // Normalize it to the "before" case. + // The 0 case is special here, since there is no last dimension to insert + // before -- we simply rely on the loop below iterating 0 times. + if (dim == inputRank && inputRank != 0) + dim = inputRank - 1; + bool alreadyCrossedExpandedDim = false; + for (int i = 0; i != inputRank; i++) { + if (alreadyCrossedExpandedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (i == dim) { + reassociationMap[i].push_back(i + 1); + alreadyCrossedExpandedDim = true; + } + } + } + Value unsqueezed = rewriter.create( + op->getLoc(), unsqueezedType, input, reassociationMap); + return unsqueezed; +} + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + Location loc = op->getLoc(); + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + + // No scope for squeezing the input. + if (inputRank == 0) + return input; + + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + // assert dynamic squeeze dim size == 1 + if (inputType.isDynamicDim(dim)) { + Value cstDim = rewriter.create(loc, dim); + Value dimVal = rewriter.create(loc, input, cstDim); + Value cstOne = rewriter.create(loc, 1); + Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, + dimVal, cstOne); + rewriter.create( + loc, cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); + } + + ArrayRef inputShape = inputType.getShape(); + SmallVector squeezedShape; + squeezedShape.append(inputShape.begin(), inputShape.begin() + dim); + squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end()); + int64_t squeezedRank = inputRank - 1; + Type squeezedType = + RankedTensorType::get(squeezedShape, inputType.getElementType()); + + // If the dim(th) dimension of operand tensor type is not statically unit, + // squeeze will behave as an identity operation. + if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { + return input; + } + + SmallVector reassociationMap(squeezedRank); + bool alreadyCrossedSqueezedDim = false; + for (int i = 0; i != squeezedRank; i++) { + if (alreadyCrossedSqueezedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (dim != 0 && i != dim - 1) + continue; + + alreadyCrossedSqueezedDim = true; + if (dim == 0) + reassociationMap[0].push_back(1); + if (i == dim - 1) + reassociationMap[i].push_back(dim); + } + } + // Note: In case the operand tensor type is of unit rank and is statically + // shaped with unit dimension, the `reassociationMap` will be empty and the + // input will be collapsed to a 0-D tensor. + Value squeezed = rewriter.create( + op->getLoc(), squeezedType, input, reassociationMap); + return squeezed; +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d2c6e6c9a762..fe3aa3c5dd41 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2731,6 +2731,7 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -2886,6 +2887,7 @@ "Conv1dModule_basic", "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", @@ -3593,6 +3595,7 @@ "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -4186,6 +4189,7 @@ "Conv1dWithSamePaddingModule_basic", "Conv1dWithValidPaddingModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 7a45dd7fc0ce..663c4b6a746b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1199,6 +1199,33 @@ def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dGroupModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d( + inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=2 + ) + + +@register_test_case(module_factory=lambda: Conv1dGroupModule()) +def Conv1dGroupModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() From 71cb94268200003ecafad76788212df8fc61c824 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 17 Dec 2024 08:03:58 -0800 Subject: [PATCH 09/17] [torch-mlir][sparse] register sparse tensor dialect for all rewriting (#3918) We incorrectly relied on the fact that StableHLO registers the sparse tensor dialect, but when building for e.g. just LinAlg, the dependency was missing. This fixes this shortcoming. FIXES: https://github.com/llvm/torch-mlir/issues/3816 --- lib/InitAll.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index c9638c8353b1..d9d7ef1a0cd4 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -52,7 +53,8 @@ void mlir::torch::registerOptionalInputDialects( mlir::DialectRegistry ®istry) { registry.insert(); + scf::SCFDialect, sparse_tensor::SparseTensorDialect, + tensor::TensorDialect, tosa::TosaDialect>(); } void mlir::torch::registerAllPasses() { From e68560d713e37f88f446e69979692ee4ef7a64b0 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Wed, 18 Dec 2024 19:42:23 -0800 Subject: [PATCH 10/17] Add attributes support for onnx.nms (#3920) - Set default attribute values - Support `max_output_boxes_per_class` attribute - e2e test `test_nonmaxsuppression_limit_output_size` passed --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 127 +++++++++++------- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 81 ++++++----- 2 files changed, 123 insertions(+), 85 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 13f555c146b4..12d8683bc9d1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -3688,6 +3688,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( patterns.onOp( "NonMaxSuppression", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); Torch::ValueTensorType resultType; SmallVector operands; int64_t centerPointBox; @@ -3702,34 +3703,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, "unimplemented: expected center_point_box " "attribute value to be 0"); - // TODO: Add support for optional arguments to be absent. - if (operands.size() < 4) - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: expected at least 4 arguments"); - + // TODO: Support multiple batches and classes // Squeeze the boxes and scores tensor. // In Onnx, the shape of boxes is [BxNx4] while the // torchvision expects it to be of shape [Nx4]. Similarly, for // the scores tensor shape in Onnx is [BxCxN] while the // torchvision expects it to be of shape [N]. Value boxes = operands[0], scores = operands[1]; - FailureOr squeezedBoxes = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, boxes); + FailureOr squeezedBoxes = + Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes); if (failed(squeezedBoxes)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze boxes tensor"); - - FailureOr squeezedScores = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, scores); + FailureOr squeezedScores = + Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores); if (failed(squeezedScores)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze scores tensor"); - squeezedScores = Torch::squeezeTensor( - rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value()); + squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0, + squeezedScores.value()); if (failed(squeezedScores)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze scores tensor"); - boxes = squeezedBoxes.value(); scores = squeezedScores.value(); @@ -3737,61 +3732,103 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Filter out the boxes if the score < score_threshold if (operands.size() == 5) { Value scoreThreshold = rewriter.create( - binder.getLoc(), rewriter.getType(), - operands[4]); + loc, rewriter.getType(), operands[4]); Value minScores = rewriter.create( - binder.getLoc(), + loc, Torch::ValueTensorType::get(binder.op->getContext(), SmallVector{}, rewriter.getF32Type()), scores); minScores = rewriter.create( - binder.getLoc(), rewriter.getType(), minScores); + loc, rewriter.getType(), minScores); Value scoresCond = rewriter.create( - binder.getLoc(), minScores, scoreThreshold); + loc, minScores, scoreThreshold); rewriter.create( - binder.getLoc(), scoresCond, + loc, scoresCond, rewriter.getStringAttr( "unimplemented: score_threshold should be <= min(scores)")); } - // TODO: Support default iou_threshold - Value iouThreshold = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[3]); + // Get max_output_boxes_per_class and iou_threshold + Value cst0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value maxOutputBoxesPerClass = cst0; + Value iouThreshold = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0)); + if (operands.size() > 3 && + !isa(operands[3].getType())) { + iouThreshold = rewriter.create( + loc, rewriter.getType(), operands[3]); + } + if (operands.size() > 2 && + !isa(operands[2].getType())) { + maxOutputBoxesPerClass = rewriter.create( + loc, rewriter.getType(), operands[2]); + } + auto nmsTy = Torch::ValueTensorType::get( + binder.op->getContext(), SmallVector{-1}, + rewriter.getIntegerType(64, /*signed=*/true)); + Value result = rewriter.create( + loc, nmsTy, boxes, scores, iouThreshold); + + // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class + Value numOutputBoxes = + rewriter.create(loc, result, cst0); + Value boxesCond = rewriter.create( + loc, numOutputBoxes, maxOutputBoxesPerClass); + + auto nmsResultTy = Torch::ValueTensorType::get( binder.op->getContext(), SmallVector{resultType.getSizes()[0]}, rewriter.getIntegerType(64, /*signed=*/true)); - Value result = rewriter.create( - binder.getLoc(), nmsTy, boxes, scores, iouThreshold); + auto ifSlice = rewriter.create( + loc, TypeRange({nmsResultTy}), boxesCond); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifSlice.getThenRegion(), + ifSlice.getThenRegion().begin()); + + Value curResult = rewriter.create( + loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0, + /*end=*/maxOutputBoxesPerClass, /*step=*/cst1); + rewriter.create(loc, curResult); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifSlice.getElseRegion(), + ifSlice.getElseRegion().begin()); + + Value curResult = rewriter.create( + loc, nmsResultTy, result); + rewriter.create(loc, curResult); + } + result = ifSlice.getResult(0); // The result generated by torchvision.nms op is of shape [n], while the // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor // and make it of shape [n, 1] and then concatenate it with a zero // tensor of shape [n, 2] to make it of shape [n, 3]. - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); FailureOr unsqueezedResult = - Torch::unsqueezeTensor(rewriter, binder.op, result, dim); + Torch::unsqueezeTensor(rewriter, binder.op, result, cst1); if (failed(unsqueezedResult)) return rewriter.notifyMatchFailure( binder.op, "failed to unsqueeze result tensor"); result = unsqueezedResult.value(); - Value numOutputBoxes = rewriter.create( - binder.getLoc(), result, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0))); + numOutputBoxes = + rewriter.create(loc, result, cst0); SmallVector zerosShapeValues{numOutputBoxes}; zerosShapeValues.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2))); + loc, rewriter.getI64IntegerAttr(2))); Value zerosShapeList = rewriter.create( - binder.getLoc(), + loc, rewriter.getType( rewriter.getType()), zerosShapeValues); - std::optional> resultShape = cast(result.getType()).getOptionalSizes(); if (!resultShape.has_value()) @@ -3800,10 +3837,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( llvm::SmallVector zerosShape = {resultShape->front(), 2}; auto zerosTy = Torch::ValueTensorType::get( resultType.getContext(), zerosShape, resultType.getOptionalDtype()); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = rewriter.create(loc); Value zeros = rewriter.create( - binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone, - cstNone); + loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone); Type listElemType = cast(resultType) @@ -3811,22 +3847,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value tensorList = rewriter.create( - binder.getLoc(), listType, SmallVector{zeros, result}); - - // TODO: Support max_output_boxes_per_class input - // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class - Value maxOutputBoxesPerClass = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[2]); - Value boxesCond = rewriter.create( - binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass); - rewriter.create( - binder.getLoc(), boxesCond, - rewriter.getStringAttr( - "unimplemented: number of output boxes per class should be " - "<= max_output_boxes_per_class")); - + loc, listType, SmallVector{zeros, result}); rewriter.replaceOpWithNewOp(binder.op, resultType, - tensorList, dim); + tensorList, cst1); return success(); }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 7f1e63d83ccd..30b85e63ab0f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -2057,22 +2057,30 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4] // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" - // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1],si64> - // CHECK: %[[VAL_26:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> - // CHECK: %[[VAL_28:.*]] = torch.constant.int 0 - // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_30:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_32:.*]] = torch.constant.none - // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list - // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class" - // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> - // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64> + // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64> + // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64> + // CHECK: } else { + // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> + // CHECK: } + // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_37:.*]] = torch.constant.none + // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list + // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> return %0 : !torch.vtensor<[1,3],si64> } @@ -2109,23 +2117,30 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" - // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],si64> - // CHECK: %[[VAL_26:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> - // CHECK: %[[VAL_28:.*]] = torch.constant.int 0 - // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_30:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_32:.*]] = torch.constant.none - // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_27]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list - // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class" - // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> - // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64> - // CHECK: } + // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64> + // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64> + // CHECK: } else { + // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> + // CHECK: } + // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_37:.*]] = torch.constant.none + // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list + // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> return %0 : !torch.vtensor<[1,3],si64> } From 061bbc5e1bc4f7880bb565e404a6709f97396818 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 19 Dec 2024 10:55:15 -0800 Subject: [PATCH 11/17] [torch] Update `torch.bmm` to use accumulator type (#3924) Batch matmul was using the result type as the accumulator. Updated to use the preferred accumulator based on input type. --- lib/Conversion/TorchToLinalg/Linear.cpp | 10 ++++++-- .../torch_mlir_e2e_test/test_suite/basic.py | 23 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 4e93804b9ca5..9073c5846f33 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -727,15 +727,21 @@ class ConvertAtenBmmOp : public OpConversionPattern { // Check the matrixs shapes are valid for mulplication. checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); + Type accumulatorDType = getDefaultAccType(rewriter, resultElementType); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, - resultElementType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType); Value bmm = rewriter .create(loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0) .getResult(0); + + if (accumulatorDType != resultElementType) { + bmm = torch_to_linalg::convertTensorToElementType(rewriter, loc, bmm, + resultElementType); + } + rewriter.replaceOpWithNewOp(op, newResultType, bmm); return success(); } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 5e3aa3bc02f6..bd6f069ee9db 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -87,6 +87,29 @@ def BmmFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) +class BmmFloat16Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float16, True), + ([-1, -1, -1], torch.float16, True), + ] + ) + def forward(self, lhs, rhs): + return torch.bmm(lhs, rhs) + + +@register_test_case(module_factory=lambda: BmmFloat16Module()) +def BmmFloat16Module_basic(module, tu: TestUtils): + module.forward( + tu.rand(3, 4, 5).to(torch.float16), tu.rand(3, 5, 4).to(torch.float16) + ) + + class BmmIntModule(torch.nn.Module): def __init__(self): super().__init__() From 51da49c3c582ac43b40416e323057290f3ad998b Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:40:39 -0800 Subject: [PATCH 12/17] [Torch] Add decomposition for 1d torch.nonzero (#3876) 2d static nonzero also work. But 2d dynamic need to be fixed next. --- .../Torch/Transforms/DecomposeComplexOps.cpp | 235 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 3 +- .../torch_mlir_e2e_test/test_suite/basic.py | 23 ++ 3 files changed, 260 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 063dca041901..24eb589cc397 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5705,6 +5705,240 @@ class DecomposeAtenConvolutionBackwardOp }; } // namespace +/** + * # one dim input + * t = torch.tensor([0, 0, 1, 1, 0, 0] + * # t_flat:[0, 0, 1, 1, 0, 0] + * t_flat = t.flatten(0, 0) + * nonzero_mask = t_flat != 0 + * # nonzero_mask:[0, 0, 1, 1, 0, 0] + * nonzero_mask = nonzero_mask.long() + * # destination_indices:[-1, -1, 0, 1, 1, 1] + * destination_indices = torch.cumsum(nonzero_mask, 0) - 1 + * # destination_indices_clamp:[0, 0, 0, 1, 1, 1] + * destination_indices_clamp = torch.clamp(destination_indices, min=0) + * # iota:[0, 0, 2, 3, 0, 0] + * iota = torch.arange(t_flat.size(0)) * nonzero_mask + * # scatter_self:[0, 0, 0, 0, 0, 0] + * scatter_self = torch.zeros_like(t_flat, dtype=torch.int64) + * # compacted:[2, 3, 0, 0, 0, 0] + * compacted = torch.scatter_add( + * scatter_self, dim=0, index=destination_indices_clamp, src=iota + * ) + * # result_flat:[2, 3] + * result_flat = compacted[: torch.sum(nonzero_mask)] + * + * # multi dim support + * original_shape = t.shape + * # input_shape_tensor:[6] + * input_shape_tensor = torch.tensor(original_shape) + * strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0) + * + * one = torch.tensor([1]) + * if(t.dim() > 1): + * slicedStrides = strides[1:-1] + * strides = torch.cat([slicedStrides, one]) + * else: + * strides = one + * # a: tensor([[2], [3]]) torch.Size([2, 1]) + * a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1]) + * # b: tensor([[1]]) torch.Size([1, 1]) + * b = strides.unsqueeze(0) + * # c: tensor([[2], [3]]) torch.Size([2, 1]) + * c = a // b + * # result: tensor([[2], [3]]) torch.Size([2, 1]) + * result = c % input_shape_tensor + */ +class DecomposeAtenNonzeroOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNonzeroOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resultType = cast(op.getType()); + auto intType = resultType.getDtype(); + Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType); + auto constantZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + auto constantOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + std::function makeOneElementList = [&](Value element) { + auto listType = Torch::ListType::get(element.getType()); + return rewriter.create(loc, listType, + ArrayRef{element}); + }; + + Value input = op.getSelf(); + auto inputType = dyn_cast(input.getType()); + int64_t inputRank = inputType.getSizes().size(); + + // t_flat = t.flatten() # torch.flatten(t, 0, 0) + int64_t flattenedSize = 1; + if (inputType.hasSizes()) { + for (auto size : inputType.getSizes()) { + flattenedSize *= size; + } + } else { + flattenedSize = kUnknownSize; + } + + auto flattendInputShape = SmallVector{flattenedSize}; + auto flattenedInputType = rewriter.getType( + flattendInputShape, inputType.getOptionalDtype()); + + // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 : + auto inputDimsEnd = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank - 1)); + Value flattenedInput = rewriter.create( + loc, flattenedInputType, input, constantZero /*inputDimsStart*/, + inputDimsEnd /*inputDimsEnd*/); + + // nonzero_mask = (t_flat != 0) + auto boolMaskType = inputType.getWithSizesAndDtype( + flattenedInputType.getOptionalSizes(), rewriter.getI1Type()); + Value boolMask = rewriter.create( + loc, boolMaskType, flattenedInput, constantZero); + + // nonzero_mask = nonzero_mask.int() + Value falseCst = rewriter.create(loc, false); + Value noneCst = rewriter.create(loc); + auto intMaskType = flattenedInputType.getWithSizesAndDtype( + flattenedInputType.getOptionalSizes(), intType); + Value intMask = rewriter.create( + loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst); + + // destination_indices = torch.cumsum(nonzero_mask, 0) - 1 + Value cumulativeSum = rewriter.create( + loc, intMaskType, intMask, constantZero, noneCst); + Value subtracted = rewriter.create( + loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne); + + // destination_indices = torch.clamp(destination_indices, min=0) + Value indices = rewriter.create(loc, intMaskType, + subtracted, constantZero); + + // iota = torch.arange(len(t_flat)) * nonzero_mask + Value end = rewriter.create(loc, flattenedInput, + /*dim=*/constantZero); + Value rangeTensor = rewriter.create( + loc, intMaskType, /*start*/ constantZero, /*end*/ end, + /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst); + Value multiplied = rewriter.create(loc, intMaskType, + rangeTensor, intMask); + + // scatter_self = torch.zeros_like(t, dtype=torch.int64) + // AtenFullLike doesn't support index type so we have to use int. + Value zerosTensor = rewriter.create( + loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst, + noneCst, noneCst); + + // compacted = torch.scatter_add( + // scatter_self, dim=0, index=destination_indices_clamp, src=iota) + Value scatteredTensor = rewriter.create( + loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero, + /*index=*/indices, /*src=*/multiplied); + + // result_flat = compacted[:torch.sum(nonzero_mask)] + auto scalarType = ValueTensorType::get(rewriter.getContext(), + ArrayRef{}, intType); + Value sumMask = + rewriter.create(loc, scalarType, intMask, noneCst); + Value numNonzero = rewriter.create(loc, sumMask); + + auto slicedResultType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize}, intType); + Value slicedResult = + rewriter.create(loc, slicedResultType, + /*self=*/scatteredTensor, + /*dim=*/constantZero, + /*start=*/noneCst, + /*end=*/numNonzero, + /*step=*/constantOne); + + // TODO fix multidim dynamic support. The following code only work for + // static multidim. Convert flattened indices back to multi-dimensional + // indices original_shape = t.shape input_shape_tensor = + // torch.tensor(original_shape) + auto shapeType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{inputRank}, intType); + SmallVector shapeValues; + for (int i = 0; i < inputRank; i++) { + auto constantI = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + Value shape = rewriter.create(loc, input, + /*dim=*/constantI); + shapeValues.push_back(shape); + } + Value shapeTensorList = rewriter.create( + loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues); + Value inputShapeTensor = rewriter.create( + loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst); + + // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0) + Value flippedShape = rewriter.create( + loc, shapeType, inputShapeTensor, makeOneElementList(constantZero)); + Value cumulativeProduct = rewriter.create( + loc, shapeType, flippedShape, constantZero, noneCst); + Value flippedCumulativeProduct = rewriter.create( + loc, shapeType, cumulativeProduct, makeOneElementList(constantZero)); + + // strides = torch.cat([strides[1:-1], torch.tensor([1])]) + auto oneTensorType = ValueTensorType::get(rewriter.getContext(), + SmallVector{1}, intType); + Value oneTensor = rewriter.create( + loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst, + noneCst); + + Value strides; + if (inputRank > 1) { + // strides[1:-1] + auto slicedStrideType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{inputRank - 1}, // sizes + intType); + Value strideSliceEnd = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank)); + Value slicedStrides = rewriter.create( + loc, slicedStrideType, /*self*/ flippedCumulativeProduct, + /*dim*/ constantZero, + /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne); + // torch.cat + auto tensorListElementType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize}, intType); + Value tensorList = rewriter.create( + loc, Torch::ListType::get(tensorListElementType), + SmallVector{slicedStrides, oneTensor}); + strides = rewriter.create(loc, shapeType, tensorList, + constantZero); + } else { + // strides[1:-1] is empty + strides = oneTensor; + } + + // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % + // input_shape_tensor + auto unsqueezedResultType = ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize, 1}, intType); + Value unsqueezedResult = rewriter.create( + loc, unsqueezedResultType, slicedResult, constantOne); + + auto unsqueezedStridesType = ValueTensorType::get( + rewriter.getContext(), SmallVector{1, inputRank}, intType); + Value unsqueezedStrides = rewriter.create( + loc, unsqueezedStridesType, strides, constantZero); + + auto dividedBroadcastType = ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize, inputRank}, + intType); + Value divided = rewriter.create( + loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides); + + Value modded = rewriter.create( + loc, resultType, divided, inputShapeTensor); + + rewriter.replaceOp(op, modded); + return success(); + } +}; + // Decompose aten.addmm into aten.mm and aten.add.Tensor op. namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { @@ -11263,6 +11497,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fe3aa3c5dd41..c266bf7ce8e5 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -399,6 +399,7 @@ "AtenIntBoolOpModule_basic", "AtenIntMM_basic", "AtenItemFpOpModule_basic", + "AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", "QuantizedReluInt32_basic", @@ -628,6 +629,7 @@ "AtenMmQMixedSigni8_basic", "AtenMmQint8_basic", "AtenMmQuint8_basic", + "AtenNonzero1DDynamicModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenTopKModule_basic", @@ -3018,7 +3020,6 @@ "LinalgNormKeepDimComplexModule_basic", "LinalgVectorNormComplexModule_basic", "LogSoftmaxBackwardModule_basic", - "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index bd6f069ee9db..927bfe85df8a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -6430,3 +6430,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils): module.forward( tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64) ) + + +# ============================================================================== + + +class AtenNonzero1DDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.ops.aten.nonzero(x) + + +@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule()) +def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)) From 2f8dbca3f4bffab93845b0c1df28e5ef25ce09df Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 19 Dec 2024 14:35:04 -0800 Subject: [PATCH 13/17] [torch-mlir] add MPACT as an example torch-mlir based compiler (#3928) Rationale: In addition to IREE and Blade, MPACT provides an MLIR-based example of a PyTorch compiler that uses TORCH-MLIR. It also illustrates propagating sparsity from sparse PyTorch into MLIR, a feature that is not widespread in DL compilers yet. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 56371b949487..53b93e840ef3 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ Torch-MLIR is primarily a project that is integrated into compilers to bridge th * [IREE](https://github.com/iree-org/iree.git) * [Blade](https://github.com/alibaba/BladeDISC) +* [MPACT](https://github.com/MPACT-ORG/mpact-compiler) While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration: From 13ee7c21fc70d891e37b511213b31dc842a5368d Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Thu, 19 Dec 2024 14:54:37 -0800 Subject: [PATCH 14/17] [TOSA] Add legalization for torch.aten.unfold (#3922) * Add Torch to TOSA legalization for torch.aten.unfold * Update e2e results in xfail_sets.py * Fix a minor detail in one of the unfold e2e tests * Add LIT tests for aten.unfold Change-Id: I6583019d1c2569bdaf9f0b67cf44b33067448af7 Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 193 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 18 +- .../test_suite/reshape_like.py | 2 +- test/Conversion/TorchToTosa/basic.mlir | 50 +++++ 4 files changed, 251 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 1c05ae49e18b..be51712a35de 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8256,6 +8256,198 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.unfold +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnfoldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Approach: Use GatherOp to retrieve target elements from target dim and then + // reshape the output into slices according to the output shape + // + // Lowering steps: + // 1. Create PyTorch-style indices tensor corresponding to target elements and + // reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1)) + // with d_x being the dimension size of the input at dim x. + // The indices vector will be calculated using the following formula: + // for i in range(d_0 * d_1 * ... * d_(target_dim - 1)): + // for window in range(nWindows): + // for elementIndex in range(size): + // for j in range(d_(target_dim + 1) * ... * d_(rank-1)): + // indices_vec.push_back(elementIndex + window * step) + // 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices + // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + // 4. Reshape result from above to correct output shape + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + int64_t dim; + if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, + "Only constant int dims are supported"); + + int64_t size; + if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) + return rewriter.notifyMatchFailure(op, + "Only constant int sizes are supported"); + + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + return rewriter.notifyMatchFailure(op, + "Only constant int steps are supported"); + + if (step <= 0) + return rewriter.notifyMatchFailure(op, "Step value must be greater than 0"); + + // Handle rank zero + if (selfRank == 0) { + if (dim != 0) + return rewriter.notifyMatchFailure( + op, "Unsupported dim value for rank zero input"); + + if (size != 1) + return rewriter.notifyMatchFailure( + op, "Unsupported size value for rank zero input"); + + auto result = rewriter.create( + op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, + rewriter.getDenseI64ArrayAttr({1})); + + rewriter.replaceOp(op, {result.getResult()}); + return success(); + } + + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Dim value is invalid"); + + // Size of dimension 'dim' in the returned tensor (or number of windows within + // the dimension that got sliced) + int64_t nWindows = (selfShape[dim] - size) / step + 1; + + // Find number of times that each base index value gets repeated for target + // dim based on dim values before and after target dim i.e. preDimAccumulate = + // d_0 * d_1 * ... * d_(target_dim - 1) + // postDimAccumulate = d_(target_dim + 1) * ... * d_(rank - 1) + int64_t preDimAccumulate = + std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1, + std::multiplies()); + int64_t postDimAccumulate = + std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1, + std::multiplies()); + + // Calculate PyTorch-style gather indices vector + // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1 + // -> preDimAccumulate = 2, postDimAccummulate = 3, nWindows = 2 + // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3] + // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3, + // 0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3] + SmallVector pyTorchIndicesBaseVec; + SmallVector pyTorchIndicesVec; + + for (int64_t window = 0; window < nWindows; window++) { + for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) { + int32_t baseIndex = static_cast(elementIndex + window * step); + for (int64_t i = 0; i < postDimAccumulate; i++) + pyTorchIndicesBaseVec.push_back(baseIndex); + } + } + + for (int64_t i = 0; i < preDimAccumulate; i++) + pyTorchIndicesVec.insert(pyTorchIndicesVec.end(), + pyTorchIndicesBaseVec.begin(), + pyTorchIndicesBaseVec.end()); + + // Create the PyTorch-style indices tensor + // Continuing with the previous example: + // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3) + // pyTorchIndices = tensor([[[0, 0, 0], + // [1, 1, 1], + // [2, 2, 2], + // [1, 1, 1], + // [2, 2, 2], + // [3, 3, 3]], + // [[0, 0, 0], + // [1, 1, 1], + // [2, 2, 2], + // [1, 1, 1], + // [2, 2, 2], + // [3, 3, 3]]]) + SmallVector pyTorchIndicesShape(selfShape); + pyTorchIndicesShape[dim] = nWindows * size; + auto pyTorchIndices = + tosa::getConstTensor(rewriter, op, pyTorchIndicesVec, + pyTorchIndicesShape) + .value(); + + // Convert PyTorch-style indices to TensorFlow-style indices + auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self, + pyTorchIndices, dim); + if (!tfIndices) + return rewriter.notifyMatchFailure(op, + "Convert PyTorch-style indices and dim " + "to TensorFlow-style indices failed"); + + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + auto gatherNdOp = tosa::convertGatherNdOp( + rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy), + self, tfIndices.value()); + if (!gatherNdOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + // Reshape to an intermediary shape where the gathered elements in dimension + // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size' + SmallVector intermediaryShape; + for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) { + if (currentDim == dim) { + intermediaryShape.push_back(nWindows); + intermediaryShape.push_back(size); + } else { + intermediaryShape.push_back(pyTorchIndicesShape[currentDim]); + } + } + + auto reshapeOp = rewriter.create( + op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy), + gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape)); + + // Permute dims to the correct result order + SmallVector permutedDims; + for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) { + if (currentDim != dim + 1) + permutedDims.push_back(static_cast(currentDim)); + } + permutedDims.push_back(static_cast(dim + 1)); + + auto permutedDimsConst = tosa::getConstTensor( + rewriter, op, + /*vec=*/permutedDims, + /*shape=*/{static_cast(selfRank + 1)}) + .value(); + + auto result = rewriter.create( + op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -8617,6 +8809,7 @@ std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( INSERT_ATENOP_PATTERN(AtenLog1pOp); INSERT_ATENOP_PATTERN(AtenLog10Op); INSERT_ATENOP_PATTERN(AtenTanOp); + INSERT_ATENOP_PATTERN(AtenUnfoldOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c266bf7ce8e5..5b4385b9904b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1698,6 +1698,8 @@ "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", @@ -1706,6 +1708,9 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_basic", + "Unfold_Module_basic", "ElementwiseErfIntModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseSigmoidIntModule_basic", @@ -3441,6 +3446,8 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "UniformModule_basic", + "UniformStaticShapeModule_basic", "AtenFftRfft2DLastDim_basic", "AtenFftRfft2DMiddleDim_basic", "IsInfiniteModule_basic", @@ -3460,11 +3467,7 @@ "MaxPool3dModule_basic", "MaxPool3dStaticModule_basic", "ViewDtypeStaticModule_basic", - "Unfold_Module_Dynamic_basic", - "Unfold_Module_Rank_4", "Unfold_Module_Rank_Zero_Size_Zero_basic", - "Unfold_Module_Rank_Zero_basic", - "Unfold_Module_basic", "ArangeZeroElementOutputModule_basic", "NumpyTRank0Module_basic", "Permute0RankModule_basic", @@ -3888,17 +3891,10 @@ "AdaptiveAvgPool2dDynamic_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", - "IndexPutImpl1DFloatNonAccumulateModule_basic", - "IndexPutImpl1DIntNonAccumulateModule_basic", - "IndexPutImpl2DFloatNonAccumulateModule_basic", - "IndexPutImpl3DFloatNonAccumulateModule_basic", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", - "OneHotModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", "ReduceFrobeniusNormModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index a8820f59c373..d1ddc42b39b1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1752,7 +1752,7 @@ def forward(self, x): return x.unfold(0, 0, 1) -@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero()) +@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero()) def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils): module.forward(tu.rand()) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 9e504c082a8c..a3d52166385a 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2943,3 +2943,53 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s } // ----- + +// CHECK-LABEL: func.func @torch.aten.unfold$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x4xi32>) -> tensor<6x4x1xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<6x4x2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x24x1xf32>) -> tensor<6x4xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[3,4,2],f32> +// CHECK: } +func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unfold$rank_zero( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32> +// CHECK: } +func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- From 02fa411801684962209744358c02dee090a7fb6f Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 19 Dec 2024 16:19:40 -0800 Subject: [PATCH 15/17] [torch-mlir][doc] remove MPACT as example (#3930) Per Stella's request --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 53b93e840ef3..56371b949487 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,6 @@ Torch-MLIR is primarily a project that is integrated into compilers to bridge th * [IREE](https://github.com/iree-org/iree.git) * [Blade](https://github.com/alibaba/BladeDISC) -* [MPACT](https://github.com/MPACT-ORG/mpact-compiler) While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration: From a6179c076bd986472c9b8c5aab591c8ad3d33043 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 20 Dec 2024 11:28:23 +0530 Subject: [PATCH 16/17] build: manually update PyTorch version (#3919) This commit sets the PyTorch and TorchVision version to nightly release 2024-12-16. This commit adds the support for `aten.rrelu_with_noise_functional` op by decomposing it. And, also updates the existing decomposition of `aten.rrelu_with_noise` op by decomposing it to the newly added `aten.rrelu_with_noise_functional` op. It also updates the e2e tests for `aten.rrelu_with_noise` op by replacing it with its functional variant which is added here: https://github.com/pytorch/pytorch/commit/f85e23818618d43351f24e38dd7aacb40543ba0e and which captures the noise mutation which was earlier a reason for the test failures during the training mode. This commit also removes the newly passing tests from the xfail_sets. --------- Signed-off-by: Vivek Khandelwal --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 33 ++++++++++- .../Transforms/AbstractInterpLibrary.cpp | 56 ++++++------------- .../Torch/Transforms/DecomposeComplexOps.cpp | 39 ++++++++++--- .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 21 ------- .../build_tools/abstract_interp_lib_gen.py | 17 ++++-- .../build_tools/torch_ods_gen.py | 3 + .../test_suite/elementwise.py | 32 ++++++++--- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 11 files changed, 121 insertions(+), 87 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 556b0aa76e93..ff1ffd7e2b62 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -310,9 +310,7 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ } def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly + AllowsTypeRefinement ]> { let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; let arguments = (ins @@ -17519,6 +17517,35 @@ def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backwar }]; } +def Torch_AtenRreluWithNoiseFunctionalOp : Torch_Op<"aten.rrelu_with_noise_functional", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$noise_out + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseFunctionalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); + } + void AtenRreluWithNoiseFunctionalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index fb0aaa7201b8..5fd05708961c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7304,6 +7304,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12599,17 +12605,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %true = torch.constant.bool true\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If.yield %3 : !torch.bool\n" -" }\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -12618,46 +12622,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %true = torch.constant.bool true\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If.yield %7 : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.prim.If %4 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If.yield %7 : !torch.bool\n" -" }\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %0#1 : !torch.int\n" +" %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 24eb589cc397..9c2a80187c93 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3791,11 +3791,7 @@ class DecomposeAtenRreluOp : public OpRewritePattern { // Create a uniform random op with low and high set to `lower` and // `upper`, respectively. Value none = rewriter.create(loc); - Value emptyTensor = rewriter.create( - loc, resType, self, constantZeroFloat, /*dtype=*/none, - /*layout=*/none, - /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); - alpha = rewriter.create(loc, resType, emptyTensor, + alpha = rewriter.create(loc, resType, self, /*from=*/lower, /*to=*/upper, /*generator=*/none); } else { @@ -3840,6 +3836,33 @@ class DecomposeAtenRreluWithNoiseOp Value lower = op.getLower(); Value upper = op.getUpper(); auto resType = cast(op.getType()); + Value cstNone = rewriter.create(loc); + Value cstFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + Value result = + rewriter + .create( + loc, resType, self, noise, lower, upper, cstFalse, cstNone) + ->getResult(0); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenRreluWithNoiseFunctionalOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseFunctionalOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getResultTypes()[0]); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3885,7 +3908,7 @@ class DecomposeAtenRreluWithNoiseOp rewriter.getI1Type()); Value oneTensor = createRank0Tensor(rewriter, loc, resType, constantOneFloat); - Value not_positive = rewriter.create( + Value not_positive = rewriter.create( loc, boolResType, self, constantZeroFloat); noise = rewriter.create(loc, resType, not_positive, alpha, oneTensor); @@ -3897,7 +3920,7 @@ class DecomposeAtenRreluWithNoiseOp rewriter.create(loc, resType, zeroTensor, scaledSelf); Value rreluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOneFloat); - rewriter.replaceOp(op, rreluOutput); + rewriter.replaceOp(op, {rreluOutput, noise}); return success(); } }; @@ -11568,6 +11591,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 25635d2c5c46..f15911e2b5ba 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -501,6 +501,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 5b4385b9904b..bb8f3a029b1d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -398,7 +398,6 @@ "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", "AtenIntMM_basic", - "AtenItemFpOpModule_basic", "AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", @@ -425,7 +424,6 @@ "CumsumModule_basic", "CumprodModule_basic", "DeformConv2D_basic", - "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", @@ -439,7 +437,6 @@ "IntFloatModule_basic", "IntImplicitModule_basic", "LenStrModule_basic", - "MulFloatModule_basic", "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", "NllLossModuleBackward1DMeanWeight_basic", @@ -464,15 +461,11 @@ "QuantizedSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "ScalarImplicitFloatModule_basic", "SplitDimDynamicModule_basic", "SplitDimStaticModule_basic", "SqrtIntModule_basic", - "SubFloatModule_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", - "TensorToFloatZeroRank_basic", - "TensorToFloat_basic", "ThresholdBackward2dMixedModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dDynamicFactor_basic", @@ -507,9 +500,6 @@ "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", - # RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSignbitModule_basic", "ElementwiseCopysignModule_basic", "BernoulliFloatModule_basic", @@ -527,9 +517,6 @@ "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", - # torch export: RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -934,9 +921,6 @@ "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", - # RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", "BernoulliFloatModule_basic", "UniformModule_basic", "UniformStaticShapeModule_basic", @@ -961,9 +945,6 @@ "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", - # torch export: RuntimeError: cannot mutate tensors with frozen storage - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", } @@ -3459,8 +3440,6 @@ "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", - "ElementwiseRreluWithNoiseTrainModule_basic", - "ElementwiseRreluWithNoiseTrainStaticModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", "MaxPool3dLargeDatadModule_basic", "MaxPool3dModuleRandomSimple_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2a980bf534fd..a73d188d7168 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -649,6 +649,9 @@ def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0 def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇rrelu_with_noise_functional〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[List[int], List[int]]: + return upstream_shape_functions.unary(self), upstream_shape_functions.unary(noise) + def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -3472,21 +3475,25 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype - assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype noise_rank, noise_dtype = noise_rank_dtype - assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) - assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype) assert self_rank == noise_rank return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇rrelu_with_noise_functional〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + noise_rank, noise_dtype = noise_rank_dtype + assert self_rank == noise_rank + return self_dtype, noise_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 4c2de094e109..930979b3c939 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1212,6 +1212,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)" ) + emit( + "aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)" + ) # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index b1745fa5b85a..3ee851611ac0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1240,13 +1240,20 @@ def __init__(self): [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] ) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True) - return torch.mean(res), torch.std(res) + out, out_noise = torch.ops.aten.rrelu_with_noise_functional( + x, noise, 0.2, 0.5, True + ) + return ( + torch.mean(out), + torch.std(out), + torch.mean(out_noise), + torch.std(out_noise), + ) @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) + module.forward(tu.rand(256, 256, low=-1, high=1), tu.rand(256, 256)) # ============================================================================== @@ -1258,16 +1265,23 @@ def __init__(self): @export @annotate_args( - [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)] + [None, ([256, 256], torch.float32, True), ([256, 256], torch.float32, True)] ) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) - return torch.mean(res), torch.std(res) + out, out_noise = torch.ops.aten.rrelu_with_noise_functional( + x, noise, 0.4, 0.6, True + ) + return ( + torch.mean(out), + torch.std(out), + torch.mean(out_noise), + torch.std(out_noise), + ) @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) + module.forward(tu.rand(256, 256, low=-1, high=1), tu.rand(256, 256)) # ============================================================================== @@ -1282,7 +1296,7 @@ def __init__(self): [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] ) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) + res = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.4, 0.6, False)[0] return torch.mean(res), torch.std(res) @@ -1301,7 +1315,7 @@ def __init__(self): @export @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)]) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) + res = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.4, 0.6, False)[0] return torch.mean(res), torch.std(res) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index ae415d496d6d..0439f8244a0b 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -798d5b7ddd08899fb62672d56044dbf1f63a4d17 +3f159d635772fa2a8fd352d96b95100d885f8169 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 83ecc622c492..7ab5a78d074f 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch/ --pre -torch==2.6.0.dev20241201 +torch==2.6.0.dev20241216 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index e0583c31e56c..be1615525984 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre -torchvision==0.20.0.dev20241201 +torchvision==0.22.0.dev20241216 From 38a0a5a6c7935f171f9900d55906e7b5c865b88c Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Mon, 23 Dec 2024 14:02:56 -0500 Subject: [PATCH 17/17] Fix output size computation for MaxPool2D for ceil_model = true. (#3890) This PR fixes the output size computation as per https://github.com/pytorch/pytorch/blob/d8c14838f164ee02b88b6e37471b71bb0373f865/torch/_meta_registrations.py#L3847 ``` if ceil_mode: if (outputSize - 1) * stride >= inputSize + pad_l: outputSize -= 1 return outputSize ``` --- lib/Conversion/TorchToLinalg/Utils.cpp | 16 ++++++++++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 8 +++-- projects/pt1/e2e_testing/xfail_sets.py | 16 ++++++++++ .../torch_mlir_e2e_test/test_suite/pooling.py | 29 +++++++++++++++++++ 4 files changed, 66 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index cf41bbcd711b..98dbc1957892 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -116,6 +116,22 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, else division = b.createOrFold(loc, dividend, strideInt); Value out = b.createOrFold(loc, division, c1); + + if (ceilMode) { + Value outMinusOneTimesStride = + b.createOrFold(loc, division, strideInt); + Value inAddLeftPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), paddingInt); + + auto reduceOutputDimCond = + b.createOrFold(loc, arith::CmpIPredicate::uge, + outMinusOneTimesStride, inAddLeftPadding); + + auto reducedDim = b.createOrFold(loc, reduceOutputDimCond, + division, out); + return castIntToIndex(b, loc, reducedDim); + } + return castIntToIndex(b, loc, out); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index be51712a35de..1c2f7d6f2a11 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5398,9 +5398,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { } else { int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; - if (ceilMode && (dimSize % stride != 0)) - return dimSize / stride + 2; - return dimSize / stride + 1; + int64_t outputDim = dimSize / stride + 1; + if (ceilMode && (dimSize % stride != 0) && + (outputDim * stride < inputDim + padBefore)) + outputDim++; + return outputDim; } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bb8f3a029b1d..1dce55f06158 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -735,6 +735,7 @@ "LenStrModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic", @@ -2255,6 +2256,7 @@ "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dStaticModule_basic", "MeanModule_basic", "MmDagModule_basic", @@ -3380,6 +3382,13 @@ "ScaledDotProductAttentionBoolMaskModule_basic", } +if torch_version_for_comparison() > version.parse("2.5.1"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible + # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } + if torch_version_for_comparison() < version.parse("2.4.0.dev"): STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - { "AtenIntMM_basic", @@ -4932,3 +4941,10 @@ "_LogSoftmaxModule_basic", "_SoftmaxModule_basic", } + +if torch_version_for_comparison() > version.parse("2.5.1"): + ONNX_TOSA_XFAIL_SET = ONNX_TOSA_XFAIL_SET | { + # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible + # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 84e0e2eb9cf5..e2eaa4cfd0fe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) +class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp2d = torch.nn.MaxPool2d( + kernel_size=6, + stride=6, + padding=3, + dilation=1, + ceil_mode=True, + ) + + @export + @annotate_args( + [ + None, + ([2, 6, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp2d(x) + + +@register_test_case( + module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule() +) +def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0)) + + # ==============================================================================