[mlir] Vectorize tensor.pad with low padding for unit dims #133808

Merged
4 commits merged into llvm:main on Apr 2, 2025

Conversation

Contributor

@nirvedhmeshram nirvedhmeshram commented Mar 31, 2025

We currently do not have masked vectorization support for tensor.pad with non-zero low padding. However, we can allow it in the special case where the result dimension after padding is a unit dim. The reason is that when we actually have a low pad on a unit dim, the input size of that dimension will be (or must be, for the IR to be correct) dynamically zero, and hence we will create an all-zero mask, which is correct. If the low pad is dynamically zero, the lowering is correct as well.
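For illustration, here is a minimal IR sketch of the newly supported shape (names such as `%src`, `%l0`, and `%cst` are placeholders mirroring the tests added below, not code from this PR):

```mlir
// Low pad %l0 is dynamic, but result dim 0 is a unit dim. If %l0 > 0 at
// runtime, correct IR implies dim 0 of %src is 0, so the mask generated
// during vectorization is all-zero along that dim and the masked
// transfer_read produces only the pad value.
%pad = tensor.pad %src low[%l0, %c0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
} : tensor<?x?xf32> to tensor<1x4xf32>
```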

Member

llvmbot commented Mar 31, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)


Full diff: https://github.com/llvm/llvm-project/pull/133808.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+7-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization-unsupported.mlir (+24)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+40)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2dcd897330d1e..4adacbe7d45ca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2178,11 +2178,14 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
                                               inputVectorSizes)))
     return failure();
 
-  if (llvm::any_of(padOp.getLow(), [](Value v) {
-        std::optional<int64_t> res = getConstantIntValue(v);
-        return !res.has_value() || res.value() != 0;
+  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
+        Value padValue = en.value();
+        unsigned pos = en.index();
+        std::optional<int64_t> res = getConstantIntValue(padValue);
+        return (!res.has_value() || res.value() != 0) &&
+               resultTensorShape[pos] != 1;
       })) {
-    LDBG("low pad must all be zero: " << padOp << "\n");
+    LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 2d1f0191eb798..f419d81d8df2b 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -305,6 +305,30 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @test_masked_vectorize_lowpad(
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index, %l0 : index)
+    -> tensor<2x4xf32> {
+  // expected-error @+3 {{Attempted to vectorize, but failed}}
+  %cst = arith.constant 42.43 : f32
+  %c0 = arith.constant 0 : index
+  %1 = tensor.pad %0 low[%l0, %c0] high[%h0, %h1]  {
+    ^bb0(%hh1: index, %hh2: index):
+      tensor.yield %cst : f32
+    } : tensor<?x?xf32> to tensor<2x4xf32>
+  return %1: tensor<2x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // With dynamically shaped source, the vectorizer infers the vector size for
 // xfer Ops from the destination tensor and, conservatively, assumes
 // out-of-bounds accesses. Out-of-bounds accesses require a pad value, but
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c6d9ec6215715..efd752e70df03 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -666,6 +666,46 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+//  CHECK-LABEL: func @test_masked_vectorize_unit_lowpad
+func.func @test_masked_vectorize_unit_lowpad(
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index, %l0 : index)
+    -> tensor<1x4xf32>
+{
+  //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  //      CHECK: %[[c0_1:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  //  CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<1x4xi1>
+  //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[c42]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32>
+  // CHECK-SAME: } : vector<1x4xi1> -> vector<1x4xf32>
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<1x4xf32>
+  //  CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+  //      CHECK: %[[masked_write:.*]] = vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+  //      CHECK: return %[[masked_write]] : tensor<1x4xf32>
+  %cst = arith.constant 42.43 : f32
+  %c0 = arith.constant 0 : index
+  %1 = tensor.pad %0 low[%l0, %c0] high[%h0, %h1]  {
+    ^bb0(%hh1: index, %hh2: index):
+      tensor.yield %cst : f32
+    } : tensor<?x?xf32> to tensor<1x4xf32>
+  return %1: tensor<1x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // Input identical as the test in vectorization-with-patterns.mlir. Output is
 // different - vector sizes are inferred (rather than user-specified) and hence
 // masking was used.
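Condensed from the CHECK lines in the new positive test, the vectorized form of the unit-low-pad case looks roughly like the sketch below (SSA names abbreviated; `%c0`/`%c1` index constants assumed):

```mlir
// Roughly the IR the test above checks for: a mask sized by the dynamic
// source dims guards the read; lanes outside the mask take the pad value.
%d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %d0, %d1 : vector<1x4xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true]}
    : tensor<?x?xf32>, vector<1x4xf32>
} : vector<1x4xi1> -> vector<1x4xf32>
%empty = tensor.empty() : tensor<1x4xf32>
%res = vector.transfer_write %read, %empty[%c0, %c0] {in_bounds = [true, true]}
    : vector<1x4xf32>, tensor<1x4xf32>
```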

@nirvedhmeshram changed the title from “[mlir] Vectorize tenosr.pad with low padding for unit dims” to “[mlir] Vectorize tensor.pad with low padding for unit dims” on Mar 31, 2025
Contributor

@Max191 Max191 left a comment

Cool that the change is quite simple! LGTM, but please wait for other reviewers to get a chance to look at it too.

Contributor

@banach-space banach-space left a comment

Thanks! Just some minor asks :)


Contributor

@banach-space banach-space left a comment

Thanks for tidying this up!

I've left a few optional nits/nice-to-haves, but this already LGTM, hence approving as is.

Signed-off-by: Nirvedh <[email protected]>
@nirvedhmeshram nirvedhmeshram merged commit 42b3f91 into llvm:main Apr 2, 2025
11 checks passed