-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Improve tensor.pack simplication pattern. #76606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
We can rewrite the op to tensor.expand_shape if the packing only happens on inner most dimension.
@llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) ChangesA tensor.pack op can be rewritten to a tensor.expand_shape op if the packing only happens on inner most dimension. This also formats the lit checks better. Full diff: https://github.com/llvm/llvm-project/pull/76606.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 67651a2e38c82d..e20450c95ffd5f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -35,10 +35,20 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+ if (!packOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+
RankedTensorType sourceType = packOp.getSourceType();
RankedTensorType destType = packOp.getDestType();
- if (sourceType.getRank() != 1 || packOp.getPaddingValue())
- return failure();
+ ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+ return rewriter.notifyMatchFailure(
+ packOp, "expects packing at the innermost dimension");
+ }
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 049076a67bae53..bdfe18acd86c53 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
-// CHECK: func.func @single_dim_packing(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
-// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+// CHECK-LABEL: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
@@ -12,13 +12,47 @@ func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
// -----
-// CHECK: func.func @single_dim_packing_with_padding(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
-// CHECK-NOT: tensor.expand_shape
-// CHECK: tensor.pack
+// CHECK-LABEL: func.func @single_dim_packing_with_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
return %0 : tensor<8x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
+func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
+ %empty = tensor.empty() : tensor<5x8x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
+ return %0 : tensor<5x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
+ %empty = tensor.empty() : tensor<8x5x32xf32>
+ %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
+ return %0 : tensor<8x5x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_packing(
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
+ %empty = tensor.empty() : tensor<8x5x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
+ return %0 : tensor<8x5x32xf32>
+}
|
@llvm/pr-subscribers-mlir-tensor Author: Han-Chung Wang (hanhanW) ChangesA tensor.pack op can be rewritten to a tensor.expand_shape op if the packing only happens on inner most dimension. This also formats the lit checks better. Full diff: https://github.com/llvm/llvm-project/pull/76606.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 67651a2e38c82d..e20450c95ffd5f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -35,10 +35,20 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+ if (!packOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(packOp, "expects no outer_dims_perm");
+
RankedTensorType sourceType = packOp.getSourceType();
RankedTensorType destType = packOp.getDestType();
- if (sourceType.getRank() != 1 || packOp.getPaddingValue())
- return failure();
+ ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != sourceType.getRank())) {
+ return rewriter.notifyMatchFailure(
+ packOp, "expects packing at the innermost dimension");
+ }
+
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 049076a67bae53..bdfe18acd86c53 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
-// CHECK: func.func @single_dim_packing(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
-// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
+// CHECK-LABEL: func.func @single_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256xf32> -> tensor<8x32xf32>
@@ -12,13 +12,47 @@ func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
// -----
-// CHECK: func.func @single_dim_packing_with_padding(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
-// CHECK-NOT: tensor.expand_shape
-// CHECK: tensor.pack
+// CHECK-LABEL: func.func @single_dim_packing_with_padding(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<255xf32>)
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.pack %arg0 padding_value(%cst : f32) inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<255xf32> -> tensor<8x32xf32>
return %0 : tensor<8x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
+func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
+ %empty = tensor.empty() : tensor<5x8x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<5x8x32xf32>
+ return %0 : tensor<5x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @packing_with_outer_dims_perm(
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @packing_with_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<8x5x32xf32> {
+ %empty = tensor.empty() : tensor<8x5x32xf32>
+ %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x256xf32> -> tensor<8x5x32xf32>
+ return %0 : tensor<8x5x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_packing(
+// CHECK-NOT: tensor.expand_shape
+// CHECK: tensor.pack
+func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x5x32xf32> {
+ %empty = tensor.empty() : tensor<8x5x32xf32>
+ %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
+ return %0 : tensor<8x5x32xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, it looks good to me.
A tensor.pack op can be rewritten to a tensor.expand_shape op if the packing only happens on inner most dimension.
This also formats the lit checks better.