Commit d6590c1
[MLIR] Add allow Insert/extract slice option to pack/unpack op (#117340)
This PR adds the two options below. Both default to true, so the original lowering behavior of the pack and unpack ops is unchanged:

- lowerPadLikeWithInsertSlice on lowerPack / structured.lower_pack (default = true)
- lowerUnpadLikeWithExtractSlice on lowerUnPack / structured.lower_unpack (default = true)

The motivation for the PR is finer-grained control over the lowering of pack and unpack ops. This is useful in particular when we want to guarantee that no additional insert_slice and extract_slice ops interfere with tiling. With the original lowering pipeline, a pack or unpack op may be lowered to an insert_slice or extract_slice when the outer dimensions are unit dimensions and no transpose is involved. Under such circumstances, those insert_slice and extract_slice ops block producer/consumer fusion in tile-and-fuse transforms. With this PR, that lowering path can be disabled so that consumer fusion goes through as expected (see the usage sketch below, mirrored from the new tests).
1 parent c5a21c1 commit d6590c1
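
For reference, the new attributes are set on the transform ops as in the tests added by this commit:

  transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
      : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">,
          !transform.op<"linalg.transpose">)

  transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
      : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">, !transform.op<"tensor.extract_slice">)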

6 files changed: +327 −17 lines

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 4 additions & 2 deletions
@@ -559,7 +559,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
     Return handles to the newly produced pad, expand_shape and transpose ops.
   }];

-  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+                   DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                  Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                  Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);

@@ -599,7 +600,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
   }];

-  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+                   DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                  Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                  Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 3 deletions
@@ -1121,7 +1121,8 @@ struct LowerPackResult {

 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                     tensor::PackOp packOp);
+                                     tensor::PackOp packOp,
+                                     bool lowerPadLikeWithInsertSlice = true);

 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
 };

 /// Rewrite pack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                           tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+            bool lowerUnpadLikeWithExtractSlice = true);

 /// Struct to hold the result of a `pack` call.
 struct PackResult {
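
A minimal sketch of calling the updated C++ entry point, assuming a surrounding rewrite pattern or transform that supplies `rewriter` and a matched `packOp` (the variable names are illustrative, not from the commit):

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

  // Pass false to disable the insert_slice shortcut, so a pad-like pack
  // still lowers to pad + expand_shape + transpose.
  FailureOr<linalg::LowerPackResult> res =
      linalg::lowerPack(rewriter, packOp,
                        /*lowerPadLikeWithInsertSlice=*/false);
  if (failed(res))
    return failure();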

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 2 deletions
@@ -1176,7 +1176,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
+  FailureOr<LowerPackResult> res =
+      lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1196,7 +1198,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
+  FailureOr<LowerUnPackOpResult> res =
+      lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 5 deletions
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
 } // namespace

 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-                                             tensor::PackOp packOp) {
+                                             tensor::PackOp packOp,
+                                             bool lowerPadLikeWithInsertSlice) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
     DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

-  if (packOp.isLikePad()) {
+  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }

-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
-                                                   tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+                    bool lowerUnpadLikeWithExtractSlice) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,

   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
-  if (unPackOp.isLikeUnPad()) {
+  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();
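
To see what the gate above changes, consider the pad-like pack used in the tests below (a sketch; shapes are borrowed from the test file):

  // With lowerPadLikeWithInsertSlice = true (default), this pad-like pack
  // may lower to tensor.pad + tensor.insert_slice.
  // With the flag false, it lowers to tensor.pad + tensor.expand_shape +
  // linalg.transpose, leaving no insert_slice to block producer/consumer fusion.
  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32)
      inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
      : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>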

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 65 additions & 5 deletions
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {

 // -----

+// This is the same as pack_as_pad, but since we explicitly added
+// {lowerPadLikeWithInsertSlice = false}, it should not be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+  // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+  // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+  // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false} : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence

@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {

 // -----

+// This is the same as unpack_as_pad, but since we explicitly added
+// {lowerUnpadLikeWithExtractSlice = false}, it should not be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
+  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false} : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)

@@ -572,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
       -> (!transform.op<"tensor.empty">,
           !transform.op<"linalg.transpose">,

@@ -627,9 +687,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: @unpack_with_outer_dims_perm
 //  CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
 //       CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-//       CHECK: %[[TRAN:.*]] = linalg.transpose
-//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
-//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
+//       CHECK: %[[TRAN:.*]] = linalg.transpose
+//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
 //  CHECK-SAME:   permutation = [1, 3, 0, 2]
 //       CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 //  CHECK-SAME:   : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -638,7 +698,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK: linalg.copy ins(%[[SLICE]]
 //  CHECK-SAME:   : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
 func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
 }
