Skip to content

[mlir][tensor] Add a pattern to simplify tensor.unpack to collapse shape #76607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 3, 2024

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Dec 30, 2023

No description provided.

@hanhanW hanhanW requested a review from chelini December 30, 2023 06:55
@hanhanW
Copy link
Contributor Author

hanhanW commented Dec 30, 2023

It depends on #76606, please take a look at 91c43c7 or wait until the PR is landed.

Copy link

github-actions bot commented Dec 30, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@hanhanW hanhanW force-pushed the simplify-unpack branch 2 times, most recently from 66e3432 to 91c43c7 Compare December 30, 2023 08:04
@hanhanW hanhanW marked this pull request as ready for review January 3, 2024 01:41
@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2024

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76607.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+43-1)
  • (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 35b519e790d1c3..e8a09c4741043b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -76,7 +76,6 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that simplify `tensor.pack` and
 /// `tensor.unpack` operations.
-/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c95ffd5f..cfd838e85c1b80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   }
 };
 
+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+                       Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+                                                    operand, reassociation);
+  }
+
+  LogicalResult matchAndRewrite(UnPackOp unpackOp,
+                                PatternRewriter &rewriter) const override {
+    if (!unpackOp.getOuterDimsPerm().empty()) {
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expects no outer_dims_perm");
+    }
+
+    RankedTensorType sourceType = unpackOp.getSourceType();
+    RankedTensorType destType = unpackOp.getDestType();
+    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+      return rewriter.notifyMatchFailure(
+          unpackOp, "expects unpacking at the innermost dimension");
+    }
+
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value collapsed = insertCollapse(
+        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(unpackOp, collapsed);
+    return success();
+  }
+};
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
 }
 
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+      patterns.getContext());
 }
 
 } // namespace tensor
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18acd86c53..b78ab9bb3fd87e 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
   %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
   return %0 : tensor<8x5x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK:         return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+  %empty = tensor.empty() : tensor<256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+  return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+  %empty = tensor.empty() : tensor<255xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+  return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+  %size = arith.muli %d0, %c32 : index
+  %empty = tensor.empty(%size) : tensor<?xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+  %empty = tensor.empty() : tensor<256x5xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+  return %0 : tensor<256x5xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76607.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+43-1)
  • (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 35b519e790d1c3..e8a09c4741043b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -76,7 +76,6 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that simplify `tensor.pack` and
 /// `tensor.unpack` operations.
-/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c95ffd5f..cfd838e85c1b80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   }
 };
 
+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+                       Type newOperandType, ArrayAttr reassociation) const {
+    if (operand.getType() == newOperandType)
+      return operand;
+    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+                                                    operand, reassociation);
+  }
+
+  LogicalResult matchAndRewrite(UnPackOp unpackOp,
+                                PatternRewriter &rewriter) const override {
+    if (!unpackOp.getOuterDimsPerm().empty()) {
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "expects no outer_dims_perm");
+    }
+
+    RankedTensorType sourceType = unpackOp.getSourceType();
+    RankedTensorType destType = unpackOp.getDestType();
+    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+      return rewriter.notifyMatchFailure(
+          unpackOp, "expects unpacking at the innermost dimension");
+    }
+
+    auto reassociation =
+        getReassociationIndicesForReshape(sourceType, destType);
+    if (!reassociation)
+      return failure();
+    Value collapsed = insertCollapse(
+        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+        getReassociationIndicesAttribute(rewriter, *reassociation));
+    rewriter.replaceOp(unpackOp, collapsed);
+    return success();
+  }
+};
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
 }
 
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+      patterns.getContext());
 }
 
 } // namespace tensor
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18acd86c53..b78ab9bb3fd87e 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
   %0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
   return %0 : tensor<8x5x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK:         return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+  %empty = tensor.empty() : tensor<256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+  return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+  %empty = tensor.empty() : tensor<255xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+  return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT:     tensor.collapse
+// CHECK:         tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+  %size = arith.muli %d0, %c32 : index
+  %empty = tensor.empty(%size) : tensor<?xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+  %empty = tensor.empty() : tensor<5x256xf32>
+  %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+  return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT:     tensor.collapse_shape
+// CHECK:         tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+  %empty = tensor.empty() : tensor<256x5xf32>
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+  return %0 : tensor<256x5xf32>
+}

Copy link
Contributor

@chelini chelini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@hanhanW hanhanW merged commit 76cb0bb into llvm:main Jan 3, 2024
@hanhanW hanhanW deleted the simplify-unpack branch January 3, 2024 17:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants