Skip to content

Commit 76cb0bb

Browse files
authored
[mlir][tensor] Add a pattern to simplify tensor.unpack to collpase shape (#76607)
1 parent 3f9f8ef commit 76cb0bb

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
7676

7777
/// Populates `patterns` with patterns that simplify `tensor.pack` and
7878
/// `tensor.unpack` operations.
79-
/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
8079
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
8180

8281
/// Populates `patterns` with patterns that fold operations like `tensor.pad`

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
6161
}
6262
};
6363

64+
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
65+
using OpRewritePattern<UnPackOp>::OpRewritePattern;
66+
67+
Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
68+
Type newOperandType, ArrayAttr reassociation) const {
69+
if (operand.getType() == newOperandType)
70+
return operand;
71+
return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
72+
operand, reassociation);
73+
}
74+
75+
LogicalResult matchAndRewrite(UnPackOp unpackOp,
76+
PatternRewriter &rewriter) const override {
77+
if (!unpackOp.getOuterDimsPerm().empty()) {
78+
return rewriter.notifyMatchFailure(unpackOp,
79+
"expects no outer_dims_perm");
80+
}
81+
82+
RankedTensorType sourceType = unpackOp.getSourceType();
83+
RankedTensorType destType = unpackOp.getDestType();
84+
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
85+
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
86+
87+
ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
88+
if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
89+
return rewriter.notifyMatchFailure(
90+
unpackOp, "expects unpacking at the innermost dimension");
91+
}
92+
93+
auto reassociation =
94+
getReassociationIndicesForReshape(sourceType, destType);
95+
if (!reassociation)
96+
return failure();
97+
Value collapsed = insertCollapse(
98+
rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
99+
getReassociationIndicesAttribute(rewriter, *reassociation));
100+
rewriter.replaceOp(unpackOp, collapsed);
101+
return success();
102+
}
103+
};
104+
64105
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
65106
/// the pad op has zero low paddings, or if `pack` has no padding values.
66107
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
191232
}
192233

193234
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
194-
patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
235+
patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
236+
patterns.getContext());
195237
}
196238

197239
} // namespace tensor

mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
5656
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
5757
return %0 : tensor<8x5x32xf32>
5858
}
59+
60+
// -----
61+
62+
// CHECK-LABEL: func.func @unpack_1d_to_collapse
63+
// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32xf32>)
64+
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
65+
// CHECK: return %[[COLLAPSED]]
66+
func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
67+
%empty = tensor.empty() : tensor<256xf32>
68+
%0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
69+
return %0 : tensor<256xf32>
70+
}
71+
72+
// -----
73+
74+
// CHECK-LABEL: func.func @unpack_to_partial_slice
75+
// CHECK-NOT: tensor.collapse
76+
// CHECK: tensor.unpack
77+
func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
78+
%empty = tensor.empty() : tensor<255xf32>
79+
%0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
80+
return %0 : tensor<255xf32>
81+
}
82+
83+
// -----
84+
85+
// CHECK-LABEL: func.func @unpack_dynamic
86+
// CHECK-NOT: tensor.collapse
87+
// CHECK: tensor.unpack
88+
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
89+
%c32 = arith.constant 32 : index
90+
%c0 = arith.constant 0 : index
91+
%d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
92+
%size = arith.muli %d0, %c32 : index
93+
%empty = tensor.empty(%size) : tensor<?xf32>
94+
%0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
95+
return %0 : tensor<?xf32>
96+
}
97+
98+
// -----
99+
100+
// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
101+
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x8x32xf32>)
102+
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
103+
// CHECK: return %[[COLLAPSED]] : tensor<5x256xf32>
104+
func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
105+
%empty = tensor.empty() : tensor<5x256xf32>
106+
%0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
107+
return %0 : tensor<5x256xf32>
108+
}
109+
110+
// -----
111+
112+
// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
113+
// CHECK-NOT: tensor.collpase_shape
114+
// CHECK: tensor.unpack
115+
func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
116+
%empty = tensor.empty() : tensor<5x256xf32>
117+
%0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
118+
return %0 : tensor<5x256xf32>
119+
}
120+
121+
// -----
122+
123+
// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
124+
// CHECK-NOT: tensor.collapse_shape
125+
// CHECK: tensor.unpack
126+
func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
127+
%empty = tensor.empty() : tensor<256x5xf32>
128+
%0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
129+
return %0 : tensor<256x5xf32>
130+
}

0 commit comments

Comments
 (0)