Skip to content

[mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr #92234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

adam-smnk
Copy link
Contributor

Extends pack/unpack perm attribute checker to account for cases when the optional outer_dims_perm attribute might be missing in one operation and the other one has explicit identity permutation. This enables canonicalizer to fold more unpack(pack(x)) variants.

Extend pack/unpack perm attribute checker to account for cases when
the optional outer_dims_perm attribute might be missing in one operation
and the other one has explicit identity permutation.
This enables canonicalizer to fold more unpack(pack(x)) variants.
@llvmbot
Copy link
Member

llvmbot commented May 15, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Adam Siemieniuk (adam-smnk)

Changes

Extends pack/unpack perm attribute checker to account for cases when the optional outer_dims_perm attribute might be missing in one operation and the other one has explicit identity permutation. This enables canonicalizer to fold more unpack(pack(x)) variants.


Full diff: https://github.com/llvm/llvm-project/pull/92234.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-1)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+26)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 414bd7459af8f..8ef447cf53a37 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
   if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
     return false;
-  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
+  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
+    return true;
+  // Outer dims permutation is optional.
+  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
+  // identity permutation.
+  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
+         isIdentityPermutation(unPackOp.getOuterDimsPerm());
 }
 
 // Return true if pack and unpack have the same tiles.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8036d996d2324..b5a82eb3e9035 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
 
 // -----
 
+// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
+// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
+// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
+// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
+func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
+  %tensor_empty = tensor.empty() : tensor<128x128xf32>
+  %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
+  %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
+  return %packed : tensor<16x16x?x?xf32>
+}
+
+// -----
+
 // CHECK: func.func @invalid_empty_negative_size
 // CHECK: %[[IDX:.*]] = index.constant
 // CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>

@adam-smnk adam-smnk requested a review from hanhanW May 15, 2024 11:30
@adam-smnk adam-smnk merged commit dcd32bd into llvm:main May 16, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants