[mlir][tensor] Add a pattern to simplify tensor.unpack to collapse shape #76607
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes: adds a SimplifyUnPackToCollapseShape pattern that rewrites a tensor.unpack into a tensor.collapse_shape when the op has static shapes, no outer_dims_perm, and a single inner dimension that is the innermost destination dimension, and removes the corresponding TODO.

Full diff: https://github.com/llvm/llvm-project/pull/76607.diff (3 files affected)
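In brief, when those conditions hold the unpack is a pure reshape. A minimal sketch of the rewrite, taken directly from the test cases in this PR:

  %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>

becomes

  %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>

The full diff follows.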
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 35b519e790d1c3..e8a09c4741043b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -76,7 +76,6 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
/// Populates `patterns` with patterns that simplify `tensor.pack` and
/// `tensor.unpack` operations.
-/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c95ffd5f..cfd838e85c1b80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
}
};
+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType, ArrayAttr reassociation) const {
+ if (operand.getType() == newOperandType)
+ return operand;
+ return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+ operand, reassociation);
+ }
+
+ LogicalResult matchAndRewrite(UnPackOp unpackOp,
+ PatternRewriter &rewriter) const override {
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no outer_dims_perm");
+ }
+
+ RankedTensorType sourceType = unpackOp.getSourceType();
+ RankedTensorType destType = unpackOp.getDestType();
+ if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+ return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+ ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+ return rewriter.notifyMatchFailure(
+ unpackOp, "expects unpacking at the innermost dimension");
+ }
+
+ auto reassociation =
+ getReassociationIndicesForReshape(sourceType, destType);
+ if (!reassociation)
+ return failure();
+ Value collapsed = insertCollapse(
+ rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+ getReassociationIndicesAttribute(rewriter, *reassociation));
+ rewriter.replaceOp(unpackOp, collapsed);
+ return success();
+ }
+};
+
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
}
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
- patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+ patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+ patterns.getContext());
}
} // namespace tensor
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18acd86c53..b78ab9bb3fd87e 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
return %0 : tensor<8x5x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+ %empty = tensor.empty() : tensor<256xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+ return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT: tensor.collapse
+// CHECK: tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+ %empty = tensor.empty() : tensor<255xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+ return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT: tensor.collapse
+// CHECK: tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+ %size = arith.muli %d0, %c32 : index
+ %empty = tensor.empty(%size) : tensor<?xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+ %empty = tensor.empty() : tensor<5x256xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+ return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+ %empty = tensor.empty() : tensor<5x256xf32>
+ %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+ return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+ %empty = tensor.empty() : tensor<256x5xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+ return %0 : tensor<256x5xf32>
+}
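For context, a minimal sketch of how a pass could pick up the new pattern through the populate entry point. The pass body below is illustrative, not part of this PR; it assumes the standard greedy rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h:

  // Inside a hypothetical pass's runOnOperation():
  RewritePatternSet patterns(&getContext());
  // Registers both SimplifyPackToExpandShape and SimplifyUnPackToCollapseShape.
  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
  // Greedily apply the patterns to the current operation.
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();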
Thanks!