Skip to content

[mlir][linalg] Bugfix for InlineScalarOperands #111534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 14, 2024
Merged

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Oct 8, 2024

This PR fixes a bug where scalarOperand is a simple scalar and should be used directly, rather than accessed via tensor.extract. Fixes #111243.

@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a bug where scalarOperand is a simple scalar and should be used directly, rather than accessed via tensor.extract. Fixes #111243.


Full diff: https://github.com/llvm/llvm-project/pull/111534.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp (+5-3)
  • (modified) mlir/test/Dialect/Linalg/inline-scalar-operands.mlir (+24)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 6db51f4b84d112..a8b46905733b8c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -78,9 +78,11 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
       for (auto idx : indices)
         indicesValues.emplace_back(
             rewriter.create<arith::ConstantIndexOp>(loc, idx));
-      Value extractedValue = rewriter.create<tensor::ExtractOp>(
-          loc, opOperand->get(), indicesValues);
-      body->getArgument(idx).replaceAllUsesWith(extractedValue);
+      Value scalarValue = opOperand->get();
+      if (isa<RankedTensorType>(scalarValue.getType()))
+        scalarValue =
+            rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+      body->getArgument(idx).replaceAllUsesWith(scalarValue);
       body->eraseArgument(idx);
     }
 
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
index 93d5b8779c7461..8384b307d2dfbd 100644
--- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
     } -> tensor<4xf32>
   return %1 : tensor<4xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
+func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
+    %0 = tensor.empty() : tensor<4xf32>
+    // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+    // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+    %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                         iterator_types = ["parallel"]}
+                         ins(%arg0, %scalar : tensor<4xf32>, f32)
+                         outs(%0 : tensor<4xf32>) {
+    // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+      // CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
+      %2 = arith.divf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}

@CoTinker
Copy link
Contributor Author

Ping~

This PR fixes a bug where `scalarOperand` is a simple scalar and should be
used directly, rather than accessed via `tensor.extract`.
@CoTinker CoTinker merged commit 4b31568 into llvm:main Oct 14, 2024
8 checks passed
@CoTinker CoTinker deleted the fix_extract branch October 14, 2024 07:38
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
This PR fixes a bug where `scalarOperand` is a simple scalar and should
be used directly, rather than accessed via `tensor.extract`. Fixes
llvm#111243.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Crash when using --linalg-inline-scalar-operands
3 participants