
[mlir][tensor] Make getMixedPadImpl return static values when possible. #85016


Merged: 1 commit merged into llvm:main on Mar 13, 2024

Conversation

hanhanW (Contributor) commented on Mar 13, 2024

If the low and high padding values are defined by constant ops (i.e., they are SSA values rather than attributes), users still prefer attributes; leaving them as values can cause type-inference failures. One such failure was introduced by 60e562d; see the drop_known_unit_constant_low_high test for more details.
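
As a hedged illustration of the mechanism (this sketch is not part of the patch): getAsOpFoldResult is an existing MLIR utility from mlir/Dialect/Utils/StaticValueUtils.h, and wrapping each dynamic pad value with it is exactly the one-line change to getMixedPadImpl shown in the diff below. The wrapper function here is hypothetical:

  #include "mlir/Dialect/Utils/StaticValueUtils.h"

  // If `v` is defined by a constant op (e.g. arith.constant 1 : index),
  // getAsOpFoldResult returns the underlying IntegerAttr; otherwise it
  // returns `v` itself wrapped as an OpFoldResult.
  mlir::OpFoldResult preferStatic(mlir::Value v) {
    return mlir::getAsOpFoldResult(v);
  }

Consumers of the mixed pad sizes can then recognize such an entry as a static value instead of an opaque Value.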

llvmbot (Member) commented on Mar 13, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

If the low and high padding values are defined by constant ops (i.e., they are SSA values rather than attributes), users still prefer attributes; leaving them as values can cause type-inference failures. One such failure was introduced by 60e562d; see the drop_known_unit_constant_low_high test for more details.


Full diff: https://github.com/llvm/llvm-project/pull/85016.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+1-1)
  • (modified) mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir (+1-2)
  • (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+20)
  • (modified) mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir (+1-2)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 670202fe4372e6..cf7f3e89079c1c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1364,7 +1364,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
       unsigned count = staticAttrs.size();
       for (unsigned idx = 0; idx < count; ++idx) {
         if (ShapedType::isDynamic(staticAttrs[idx]))
-          res.push_back(values[numDynamic++]);
+          res.push_back(getAsOpFoldResult(values[numDynamic++]));
         else
           res.push_back(builder.getI64IntegerAttr(staticAttrs[idx]));
       }
diff --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
index 238c0c51312a6b..a0a676edceb745 100644
--- a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
+++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
@@ -22,7 +22,6 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK-LABEL:   func @generalize_pad_tensor_dynamic_shape(
 // CHECK-SAME:                                              %[[IN:.*]]: tensor<4x?x2x?xf32>,
 // CHECK-SAME:                                              %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
@@ -33,7 +32,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK:           %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
 // CHECK:           %[[INIT:.*]] = tensor.empty(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : tensor<4x?x?x?xf32>
 // CHECK:           %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
-// CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
+// CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
 // CHECK:           return %[[PADDED]] : tensor<4x?x?x?xf32>
 // CHECK:         }
 func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index f2c490b832076f..c140b6abcc37a2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1033,3 +1033,23 @@ func.func @do_not_drop_non_constant_padding(%arg0: tensor<1x1x3x1x1xf32>, %pad:
 // CHECK-SLICES-LABEL: func @do_not_drop_non_constant_padding
 //       CHECK-SLICES:   tensor.pad %{{.*}} low[0, 1, 0, %c0, 0] high[0, 0, 0, %c0, 2]
 //       CHECK-SLICES:   } : tensor<1x1x3x1x1xf32> to tensor<1x2x3x1x3xf32>
+
+// -----
+
+func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> tensor<1x384x128xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[%c0, %c1, %c0] high[%c0, %c0, %c0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    tensor.yield %cst : f32
+  } : tensor<1x383x128xf32> to tensor<1x384x128xf32>
+  return %padded : tensor<1x384x128xf32>
+}
+// CHECK-LABEL: func @drop_known_unit_constant_low_high
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
+//  CHECK-SAME:     {{\[}}[0, 1], [2]] : tensor<1x383x128xf32> into tensor<383x128xf32>
+//       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0] high[0, 0]
+//       CHECK:   } : tensor<383x128xf32> to tensor<384x128xf32>
+//       CHECK:   tensor.expand_shape %[[PADDED]]
+//  CHECK-SAME:     {{\[}}[0, 1], [2]] : tensor<384x128xf32> into tensor<1x384x128xf32>
diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
index ac0eb48fb37940..2beab31b613d54 100644
--- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
@@ -19,7 +19,6 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK-LABEL:   func @generalize_pad_tensor_dynamic_shape(
 // CHECK-SAME:                                              %[[IN:.*]]: tensor<4x?x2x?xf32>,
 // CHECK-SAME:                                              %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
@@ -32,7 +31,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK:           %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<4x?x?x?xf32>) -> tensor<4x?x?x?xf32>
 // CHECK:           %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
 // CHECK:           %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
-// CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
+// CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
 // CHECK:           return %[[PADDED]] : tensor<4x?x?x?xf32>
 // CHECK:         }
 func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
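
For a concrete view of the behavior change, here is a hedged sketch (not part of the patch) of how a consumer observes the pad sizes of the tensor.pad from the drop_known_unit_constant_low_high test above; inspectPad is a hypothetical helper:

  #include "mlir/Dialect/Tensor/IR/Tensor.h"

  void inspectPad(mlir::tensor::PadOp padOp) {
    for (mlir::OpFoldResult ofr : padOp.getMixedLowPad()) {
      // Before this patch, the %c0/%c1 low operands came back as Values,
      // so this check failed; with the patch they come back as
      // IntegerAttrs, as if the op had been written with static pads.
      if (ofr.is<mlir::Attribute>()) {
        // Static pad size: usable for type inference and for the
        // unit-extent-dim dropping pattern exercised by the test.
      }
    }
  }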

hanhanW added a commit to iree-org/iree that referenced this pull request Mar 13, 2024
@hanhanW hanhanW merged commit bb82092 into llvm:main Mar 13, 2024
@hanhanW hanhanW deleted the pad-get-more-static-low-high branch March 13, 2024 15:52