Skip to content

Commit 012f6d4

Browse files
committed
[NFC] Add allowInsertSliceLowering to packOp and allowExtractSliceLowering to UnPackOp
1 parent 667e1fa commit 012f6d4

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
548548
Return handles to the newly produced pad, expand_shape and transpose ops.
549549
}];
550550

551-
let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
551+
let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
552+
DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
552553
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
553554
Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
554555
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -588,7 +589,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
588589
Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
589590
}];
590591

591-
let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
592+
let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
593+
DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
592594
let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
593595
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
594596
Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
11211121

11221122
/// Rewrite pack as pad + reshape + transpose.
11231123
FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
1124-
tensor::PackOp packOp);
1124+
tensor::PackOp packOp,
1125+
bool allowInsertSliceLowering = true);
11251126

11261127
struct LowerUnPackOpResult {
11271128
tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
11311132
};
11321133

11331134
/// Rewrite pack as empty + transpose + reshape + extract_slice.
1134-
FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
1135-
tensor::UnPackOp unPackOp);
1135+
FailureOr<LowerUnPackOpResult>
1136+
lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
1137+
bool allowExtractSliceLowering = true);
11361138

11371139
/// Struct to hold the result of a `pack` call.
11381140
struct PackResult {

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
11711171
transform::ApplyToEachResultList &transformResults,
11721172
transform::TransformState &state) {
11731173
rewriter.setInsertionPoint(target);
1174-
FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
1174+
bool allowInsertSliceLowering = getAllowInsertSliceLowering();
1175+
FailureOr<LowerPackResult> res =
1176+
lowerPack(rewriter, target, allowInsertSliceLowering);
11751177
if (failed(res)) {
11761178
return mlir::emitSilenceableFailure(target->getLoc())
11771179
<< "cannot lower to pad + expand + transpose";
@@ -1191,7 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
11911193
transform::ApplyToEachResultList &transformResults,
11921194
transform::TransformState &state) {
11931195
rewriter.setInsertionPoint(target);
1194-
FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
1196+
bool allowExtractSliceLowering = getAllowExtractSliceLowering();
1197+
FailureOr<LowerUnPackOpResult> res =
1198+
lowerUnPack(rewriter, target, allowExtractSliceLowering);
11951199
if (failed(res)) {
11961200
DiagnosedSilenceableFailure diag =
11971201
emitSilenceableError()

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
217217
} // namespace
218218

219219
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
220-
tensor::PackOp packOp) {
220+
tensor::PackOp packOp,
221+
bool allowInsertSliceLowering) {
221222
// 1. Filter out NYI cases.
222223
auto packedTensorType =
223224
cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
295296
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
296297
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
297298

298-
if (packOp.isLikePad()) {
299+
if (allowInsertSliceLowering && packOp.isLikePad()) {
299300
// Pack ops which operate as simple pads may not produce legal
300301
// tensor.insert_slice operations when the packed type does not rank reduce
301302
// to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
351352
return LowerPackResult{padOp, reshapeOp, transposeOp};
352353
}
353354

354-
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
355-
tensor::UnPackOp unPackOp) {
355+
FailureOr<LowerUnPackOpResult>
356+
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
357+
bool allowExtractSliceLowering) {
356358
Location loc = unPackOp->getLoc();
357359
OpBuilder::InsertionGuard g(rewriter);
358360
rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
362364

363365
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
364366
auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
365-
if (unPackOp.isLikeUnPad()) {
367+
if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
366368
// This unpack is just a plain unpad.
367369
// Just extract the slice from the higher ranked tensor.
368370
ArrayRef<int64_t> destShape = destTensorType.getShape();

0 commit comments

Comments
 (0)