-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Handle reassociationIndices correctly for 0D tensor #121683
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a bug where a value is assigned to a 0-sized reassociationIndices, preventing a crash. Fixes #116043. Full diff: https://github.com/llvm/llvm-project/pull/121683.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..ac4078a9ffe0cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -611,10 +611,13 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
shapedType.getRank());
int64_t index = 0;
- for (index = 0; index <= numExtraDims; index++)
- reassociationIndices[0].push_back(index);
- for (size_t position = 1; position < reassociationIndices.size(); position++)
- reassociationIndices[position].push_back(index++);
+ if (shapedType.getRank() != 0) {
+ for (index = 0; index <= numExtraDims; index++)
+ reassociationIndices[0].push_back(index);
+ for (size_t position = 1; position < reassociationIndices.size();
+ position++)
+ reassociationIndices[position].push_back(index++);
+ }
// Compute result type
SmallVector<int64_t> resultShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..651d773f729396 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1964,3 +1964,26 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
return %0: tensor<1xi64>
}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @test_add_0d_broadcast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[IN_0.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN_0]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<2x1xf32>
+// CHECK: return %[[RESULT]] : tensor<2x1xf32>
+// CHECK: }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+ %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+ return %0 : tensor<2x1xf32>
+}
|
b7403bd
to
6fa17ac
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thank you!
LGTM % some minor comments.
func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> { | ||
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32> | ||
return %0 : tensor<2x1xf32> | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make sure all tests for tosa.add
are clustered together.
reassociationIndices[0].push_back(index); | ||
for (size_t position = 1; position < reassociationIndices.size(); position++) | ||
reassociationIndices[position].push_back(index++); | ||
if (shapedType.getRank() != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Perhaps it's just me, but shapedType
is a very enigmatic name and doesn't add much. I'd rather rename it to e.g. srcTensorType
. This is unrelated to this PR, but I would welcome an update :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I find srcTensorType
cleaner. Prefixes like src
in my mind are associated with some kind of value positional information which is not the case here. In general, I find the current naming ok.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A good name is the best documentation :)
srcTensorType
might not be perfect, but IMHO, it's still better than something like:
ShapedType shapedType;
That's similar to:
int Int;
float Float;
We're missing an opportunity to add meaningful context to the variable name here. Instead, we're repeating information already conveyed by the type name.
From my perspective, when I see shapedType
in GitHub (which lacks proper code highlighting), it can be cognitively taxing - am I looking at a variable or a type? While this is less of an issue in a proper editor that distinguishes between the two, it would be helpful if variable naming optimised for code reviews on GitHub as well.
Perhaps tensorType
?
Anyway, like I mentioned earlier, this is just a nice-to-have.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tensorType
sounds like a good alternative!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
This PR fixes a bug where a value is assigned to a 0-sized reassociationIndices, preventing a crash.
This PR fixes a bug where a value is assigned to a 0-sized reassociationIndices, preventing a crash. Fixes #116043.