
[mlir][tensor] Fix tensor::PackOp fold() handling of padding value #87296


Merged: 3 commits merged from pack-fold-cst-bug into llvm:main on Apr 2, 2024

Conversation

hanhanW (Contributor) commented Apr 2, 2024

We can't just check if it is a splat constant; we should also check if the values match.

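For reference, the essence of the fix is the extra value comparison in reshapeConstantSource(), shown in the diff below. Here is a minimal sketch of the guarded condition, with the reasoning spelled out in a comment; it mirrors the patch rather than adding anything new:

    // Fold only when the source is a splat constant, the result shape is
    // static, and, if a padding value is supplied, that padding value equals
    // the splat value. Otherwise the padded elements would take a different
    // value than the splat, and resizing the splat would be incorrect.
    static OpFoldResult
    reshapeConstantSource(DenseElementsAttr source, TensorType result,
                          std::optional<Attribute> cst = std::nullopt) {
      if (source && source.isSplat() && result.hasStaticShape() &&
          (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
        return source.resizeSplat(result);
      return {};
    }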
llvmbot (Member) commented Apr 2, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Han-Chung Wang (hanhanW)

Changes

We can't just check if it is a splat constant; we should also check if the values match.


Full diff: https://github.com/llvm/llvm-project/pull/87296.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+9-4)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+33)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 38a9ad60bb7948..8dc1ef67ce65c5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1069,9 +1069,11 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// Try to remove a tensor operation if it would only reshape a constant.
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape.
-static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
-                                          TensorType result) {
-  if (source && source.isSplat() && result.hasStaticShape())
+static OpFoldResult
+reshapeConstantSource(DenseElementsAttr source, TensorType result,
+                      std::optional<Attribute> cst = std::nullopt) {
+  if (source && source.isSplat() && result.hasStaticShape() &&
+      (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
     return source.resizeSplat(result);
 
   return {};
@@ -4143,9 +4145,12 @@ bool PackOp::isLikePad() {
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  std::optional<Attribute> paddingValue;
+  if (adaptor.getPaddingValue())
+    paddingValue = adaptor.getPaddingValue();
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getResult().getType()))
+          adaptor.getDestType(), paddingValue))
     return reshapedSource;
   return {};
 }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9ab54fe9c133db..ac365c9d297e88 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -830,6 +830,39 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
 
 // -----
 
+// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
+//   CHECK-NOT: tensor.pack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %pad = arith.constant 1.000000e-01 : f32
+  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+  %0 = tensor.pack %cst
+    padding_value(%pad : f32)
+    outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+//       CHECK: tensor.pack
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %pad = arith.constant 0.0 : f32
+  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+  %0 = tensor.pack %cst
+    padding_value(%pad : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32]
+    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
 func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>

joker-eph (Collaborator) commented on the new fold() code in TensorOps.cpp:

    OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
      std::optional<Attribute> paddingValue;
      if (adaptor.getPaddingValue())
        paddingValue = adaptor.getPaddingValue();

Nit: we should avoid the double call to getPaddingValue() here: this accessor is far from free (there is a loop hidden behind it).
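One way to address this nit, sketched as a fragment of fold() rather than the exact code that landed, is to call the accessor once and cache the result (this assumes the adaptor accessor returns an Attribute, as suggested by the diff above):

    // Call getPaddingValue() once and reuse the result instead of invoking
    // the non-trivial adaptor accessor twice.
    std::optional<Attribute> paddingValue;
    if (Attribute pad = adaptor.getPaddingValue())
      paddingValue = pad;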

joker-eph changed the title from "[mlir][tensor] Fix a bug in constant packing folding." to "[mlir][tensor] Fix tensor::PackOp fold() handling of padding" on Apr 2, 2024
joker-eph (Collaborator) commented on the updated reshapeConstantSource() signature:

    static OpFoldResult
    reshapeConstantSource(DenseElementsAttr source, TensorType result,
                          std::optional<Attribute> cst = std::nullopt) {

Please update the documentation for this.
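A sketch of the kind of doc-comment update being requested; the wording here is illustrative, not necessarily what was committed:

    /// Try to remove a tensor operation if it would only reshape a constant.
    /// Removes the op and replaces the constant with a new constant of the
    /// result shape. When `cst` is provided, the source splat value must
    /// additionally equal `cst` for the fold to apply.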

joker-eph changed the title from "[mlir][tensor] Fix tensor::PackOp fold() handling of padding" to "[mlir][tensor] Fix tensor::PackOp fold() handling of padding value" on Apr 2, 2024
chelini (Contributor) left a comment

It looks good to me after addressing Mehdi's comments.

chelini (Contributor) commented Apr 2, 2024

FYI: @rengolin

hanhanW requested a review from joker-eph on April 2, 2024 at 20:26
hanhanW merged commit c3e3d59 into llvm:main on Apr 2, 2024
hanhanW deleted the pack-fold-cst-bug branch on April 2, 2024 at 20:49