Skip to content

[mlir][linalg] Handle reassociationIndices correctly for 0D tensor #121683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2025

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Jan 5, 2025

This PR fixes a bug where a value is assigned to a 0-sized reassociationIndices, preventing a crash. Fixes #116043.

@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a bug where a value is assigned to a 0-sized reassociationIndices, preventing a crash. Fixes #116043.


Full diff: https://github.com/llvm/llvm-project/pull/121683.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+7-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+23)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..ac4078a9ffe0cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -611,10 +611,13 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
   SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
       shapedType.getRank());
   int64_t index = 0;
-  for (index = 0; index <= numExtraDims; index++)
-    reassociationIndices[0].push_back(index);
-  for (size_t position = 1; position < reassociationIndices.size(); position++)
-    reassociationIndices[position].push_back(index++);
+  if (shapedType.getRank() != 0) {
+    for (index = 0; index <= numExtraDims; index++)
+      reassociationIndices[0].push_back(index);
+    for (size_t position = 1; position < reassociationIndices.size();
+         position++)
+      reassociationIndices[position].push_back(index++);
+  }
 
   // Compute result type
   SmallVector<int64_t> resultShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..651d773f729396 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1964,3 +1964,26 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
   %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
   return %0: tensor<1xi64>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @test_add_0d_broadcast(
+// CHECK-SAME:                                     %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME:                                     %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK:           %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK:           %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK:           ^bb0(%[[IN:.*]]: f32, %[[IN_0.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:             %[[ADD:.*]] = arith.addf %[[IN]], %[[IN_0]] : f32
+// CHECK:             linalg.yield %[[ADD]] : f32
+// CHECK:           } -> tensor<2x1xf32>
+// CHECK:           return %[[RESULT]] : tensor<2x1xf32>
+// CHECK:         }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+  return %0 : tensor<2x1xf32>
+}

@CoTinker CoTinker force-pushed the expand_rank branch 2 times, most recently from b7403bd to 6fa17ac Compare January 5, 2025 09:03
Copy link
Member

@sahas3 sahas3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thank you!

LGTM % some minor comments.

This PR fixes a bug where a value is assigned to a 0-sized
reassociationIndices, preventing a crash.
@CoTinker CoTinker merged commit 9190e1c into llvm:main Jan 10, 2025
8 checks passed
@CoTinker CoTinker deleted the expand_rank branch January 10, 2025 01:23
BaiXilin pushed a commit to BaiXilin/llvm-fix-vnni-instr-types that referenced this pull request Jan 12, 2025
…lvm#121683)

This PR fixes a bug where a value is assigned to a 0-sized
reassociationIndices, preventing a crash. Fixes llvm#116043.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR]-pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg))" triggers Assertion Failure `idx < size()'
5 participants