
[mlir][linalg] Handle reassociationIndices correctly for 0D tensor #121683

Open: wants to merge 1 commit into main from expand_rank
Conversation

CoTinker (Contributor) commented Jan 5, 2025

This PR fixes a crash in the TosaToLinalg `expandRank` helper: for a 0-D tensor, `reassociationIndices` is created with size equal to the tensor's rank, i.e. zero, so pushing indices into `reassociationIndices[0]` wrote out of bounds. The index-building loops are now skipped for rank 0, leaving the reassociation list empty. Fixes #116043.

llvmbot (Member) commented Jan 5, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a crash in the TosaToLinalg `expandRank` helper: for a 0-D tensor, `reassociationIndices` is created with size equal to the tensor's rank, i.e. zero, so pushing indices into `reassociationIndices[0]` wrote out of bounds. The index-building loops are now skipped for rank 0, leaving the reassociation list empty. Fixes #116043.
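To make the failure mode concrete, here is a minimal, hypothetical sketch (not the actual MLIR sources) that uses std::vector in place of llvm::SmallVector. The outer vector is sized by the tensor's rank, so for a 0-D tensor any access to element 0 is out of bounds:

```cpp
// Illustrative sketch of the pre-fix failure mode; names mirror the
// patch but this is not the real MLIR code.
#include <cstdint>
#include <vector>

int main() {
  int64_t rank = 0;         // a 0-D tensor has rank 0
  int64_t numExtraDims = 2; // leading dims to add, e.g. expanding to 2-D

  // Mirrors: SmallVector<SmallVector<int64_t, 2>> reassociationIndices(rank);
  std::vector<std::vector<int64_t>> reassociationIndices(rank);

  // Pre-fix, these loops ran unconditionally. With rank == 0 the outer
  // vector is empty, so reassociationIndices[0] is an out-of-bounds
  // access: the crash reported in #116043.
  if (rank != 0) { // the guard this PR adds
    int64_t index = 0;
    for (index = 0; index <= numExtraDims; index++)
      reassociationIndices[0].push_back(index);
    for (size_t position = 1; position < reassociationIndices.size();
         position++)
      reassociationIndices[position].push_back(index++);
  }
  // For a 0-D source the reassociation list stays empty.
  return 0;
}
```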


Full diff: https://github.com/llvm/llvm-project/pull/121683.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+7-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+23)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..ac4078a9ffe0cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -611,10 +611,13 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
   SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
       shapedType.getRank());
   int64_t index = 0;
-  for (index = 0; index <= numExtraDims; index++)
-    reassociationIndices[0].push_back(index);
-  for (size_t position = 1; position < reassociationIndices.size(); position++)
-    reassociationIndices[position].push_back(index++);
+  if (shapedType.getRank() != 0) {
+    for (index = 0; index <= numExtraDims; index++)
+      reassociationIndices[0].push_back(index);
+    for (size_t position = 1; position < reassociationIndices.size();
+         position++)
+      reassociationIndices[position].push_back(index++);
+  }
 
   // Compute result type
   SmallVector<int64_t> resultShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..651d773f729396 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1964,3 +1964,26 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
   %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
   return %0: tensor<1xi64>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @test_add_0d_broadcast(
+// CHECK-SAME:                                     %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME:                                     %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK:           %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK:           %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK:           ^bb0(%[[IN:.*]]: f32, %[[IN_0:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:             %[[ADD:.*]] = arith.addf %[[IN]], %[[IN_0]] : f32
+// CHECK:             linalg.yield %[[ADD]] : f32
+// CHECK:           } -> tensor<2x1xf32>
+// CHECK:           return %[[RESULT]] : tensor<2x1xf32>
+// CHECK:         }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+  return %0 : tensor<2x1xf32>
+}
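
As an aside (not part of the patch): with the fix, a 0-D source yields an empty reassociation list, which is what `tensor.expand_shape` expects when expanding a 0-D tensor. There are no source dimensions to group, so the output shape alone drives the expansion, as in this snippet mirroring the CHECK line above:

```mlir
// A 0-D -> 2-D expansion uses an empty reassociation list ([]);
// the output_shape attribute alone determines the result type.
%expanded = tensor.expand_shape %arg1 [] output_shape [1, 1]
    : tensor<f32> into tensor<1x1xf32>
```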

CoTinker force-pushed the expand_rank branch 2 times, most recently from b7403bd to 6fa17ac on January 5, 2025 at 09:03.
sahas3 (Contributor) left a comment:

LGTM.

banach-space (Contributor) left a comment:

Makes sense, thank you!

LGTM % some minor comments.

Comment on lines 1986 to 1989:

func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
  return %0 : tensor<2x1xf32>
}
A contributor commented:
Please make sure all tests for tosa.add are clustered together.

    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);
  if (shapedType.getRank() != 0) {
A contributor commented:
[nit] Perhaps it's just me, but shapedType is a very enigmatic name and doesn't add much. I'd rather rename it to e.g. srcTensorType. This is unrelated to this PR, but I would welcome an update :)

GeorgeARM (Contributor) commented Jan 6, 2025:
Not sure I find srcTensorType cleaner. Prefixes like src are, in my mind, associated with some kind of positional information about the value, which is not the case here. In general, I find the current naming ok.

A contributor replied:
A good name is the best documentation :)

srcTensorType might not be perfect, but IMHO, it's still better than something like:

ShapedType shapedType;

That's similar to:

int Int;
float Float;

We're missing an opportunity to add meaningful context to the variable name here. Instead, we're repeating information already conveyed by the type name.

From my perspective, when I see shapedType on GitHub (which lacks proper code highlighting), it can be cognitively taxing: am I looking at a variable or a type? While this is less of an issue in a proper editor that distinguishes between the two, it would be helpful if variable naming were optimised for code reviews on GitHub as well.

Perhaps tensorType?

Anyway, like I mentioned earlier, this is just a nice-to-have.

GeorgeARM (Contributor) commented Jan 7, 2025:
tensorType sounds like a good alternative!

CoTinker (Contributor, Author) replied:
Done
