
[mlir][tensor] Fold padding_value away for pack ops when possible. #74005


Merged
hanhanW merged 3 commits into llvm:main from pack-fold-pad on Dec 1, 2023

Conversation

@hanhanW (Contributor) commented Nov 30, 2023

If we can infer statically that there are no incomplete tiles, we can remove the optional padding operand.

Fixes iree-org/iree#15417
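
A minimal before/after sketch of the fold, using hypothetical shapes (not taken from the PR; 256 and 128 divide evenly by the inner tiles 16 and 32, so no incomplete tiles can occur):

  // Before canonicalization: the padding value can never be read, because
  // every inner tile divides its source dimension exactly.
  %cst = arith.constant 0.000000e+00 : f32
  %packed = tensor.pack %src padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [16, 32]
      into %dest : tensor<256x128xf32> -> tensor<16x4x16x32xf32>

  // After canonicalization: the optional padding_value operand is dropped.
  %packed = tensor.pack %src
      inner_dims_pos = [0, 1] inner_tiles = [16, 32]
      into %dest : tensor<256x128xf32> -> tensor<16x4x16x32xf32>

Running mlir-opt --canonicalize on such IR exercises the new pattern.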

@llvmbot (Member) commented Nov 30, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Han-Chung Wang (hanhanW)



Full diff: https://github.com/llvm/llvm-project/pull/74005.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+30-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+49)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5bfcb35127b5267..30b719924dea56f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3800,17 +3800,37 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
   return true;
 }
 
-/// Fold an unpack(pack(x)) to x.
+/// Returns true if the pack op does not need a padding value.
+static bool paddingIsNotNeeded(PackOp op) {
+  auto srcType = op.getSourceType();
+  if (llvm::any_of(op.getInnerDimsPos(),
+                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
+    return false;
+  return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
+                                      op.getMixedTiles());
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
-  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
-  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(packOp, unPackOp.getSource());
-  return success();
+  // Fold an unpack(pack(x)) to x.
+  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
+    if (unPackOp.getSourceType() != packOp.getDestType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(packOp, unPackOp.getSource());
+    return success();
+  }
+
+  // Fold optional PaddingValue operand away if padding is not needed.
+  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+    rewriter.startRootUpdate(packOp);
+    packOp.getPaddingValueMutable().clear();
+    rewriter.finalizeRootUpdate(packOp);
+    return success();
+  }
+  return failure();
 }
 
 template <typename PackOrUnpackOp>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 41bfd6fe7b6eedc..68ce35c866eb034 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -719,6 +719,55 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
 
 // -----
 
+func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack
+// CHECK-NOT:     padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative1
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
+  return %pack : tensor<?x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative2
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
 // CHECK-LABEL: func @fold_unpack_constant_splat
 //   CHECK-NOT: tensor.unpack
 //       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>

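A worked note on the three new tests (the arithmetic is spelled out here, not in the PR): in @fold_padding_value_pack, 500000 = 16 x 31250 and 1200 = 1 x 1200, so every inner tile divides its source dimension exactly and the padding value is provably unused; in @fold_padding_value_pack_negative1, 499999 = 16 x 31249 + 15 leaves an incomplete tile, so padding_value must be kept; in @fold_padding_value_pack_negative2, the tiled source dimension is dynamic, so paddingIsNotNeeded returns false because divisibility cannot be proven statically.
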
@hanhanW requested a review from chelini on December 1, 2023 00:02
@chelini (Contributor) left a comment


thanks, looks good to me.

@hanhanW merged commit 171cac9 into llvm:main on Dec 1, 2023
@hanhanW deleted the pack-fold-pad branch on December 1, 2023 19:13
Successfully merging this pull request may close this issue:

Fold padding_value away if a tensor.pack op does not have incomplete tile