Skip to content

Commit 4b14205

Browse files
authored
[mlir][tensor] Centralize pack/unpack related patterns. (#76603)
The revision moves pack/unpack related patterns to PackAndUnpackPatterns.cpp. This follows the convention like other tensor ops. It also renames `populateSimplifyTensorPack` to `populateSimplifyPackAndUnpackPatterns` and adds a TODO item for tensor.unpack op.
1 parent 8346e86 commit 4b14205

File tree

7 files changed

+48
-50
lines changed

7 files changed

+48
-50
lines changed

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,6 @@ void populateFoldConstantExtractSlicePatterns(
163163
return false;
164164
});
165165

166-
/// Patterns to simplify tensor.pack.
167-
void populateSimplifyTensorPack(RewritePatternSet &patterns);
168-
169166
} // namespace tensor
170167
} // namespace mlir
171168

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
7474
/// that it can be bufferized into a sequence of copies.
7575
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
7676

77+
/// Populates `patterns` with patterns that simplify `tensor.pack` and
78+
/// `tensor.unpack` operations.
79+
/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
80+
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
81+
7782
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
7883
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
7984
/// respectively.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3466,44 +3466,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
34663466
// PackOp/UnPackOp Common
34673467
//===----------------------------------------------------------------------===//
34683468

3469-
namespace {
3470-
3471-
/// Packing one-dimensional tensor can be expressed as an expand shape op.
3472-
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
3473-
using OpRewritePattern<PackOp>::OpRewritePattern;
3474-
3475-
Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
3476-
Type newOperandType, ArrayAttr reassociation) const {
3477-
if (operand.getType() == newOperandType)
3478-
return operand;
3479-
return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
3480-
reassociation);
3481-
}
3482-
3483-
LogicalResult matchAndRewrite(PackOp packOp,
3484-
PatternRewriter &rewriter) const override {
3485-
RankedTensorType sourceType = packOp.getSourceType();
3486-
RankedTensorType destType = packOp.getDestType();
3487-
if (sourceType.getRank() != 1 || packOp.getPaddingValue())
3488-
return failure();
3489-
auto reassociation =
3490-
getReassociationIndicesForReshape(sourceType, destType);
3491-
if (!reassociation)
3492-
return failure();
3493-
Value expanded = insertExpand(
3494-
rewriter, packOp.getLoc(), packOp.getSource(), destType,
3495-
getReassociationIndicesAttribute(rewriter, *reassociation));
3496-
rewriter.replaceOp(packOp, expanded);
3497-
return success();
3498-
}
3499-
};
3500-
3501-
} // namespace
3502-
3503-
void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
3504-
patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
3505-
}
3506-
35073469
template <typename OpTy>
35083470
static LogicalResult
35093471
reifyResultShapesImpl(OpTy op, OpBuilder &builder,

mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ add_mlir_dialect_library(MLIRTensorTransforms
44
ConcatOpPatterns.cpp
55
EmptyOpPatterns.cpp
66
ExtractSliceFromReshapeUtils.cpp
7-
FoldIntoPackAndUnpackPatterns.cpp
87
FoldTensorSubsetOps.cpp
98
IndependenceTransforms.cpp
109
MergeConsecutiveInsertExtractSlicePatterns.cpp
10+
PackAndUnpackPatterns.cpp
1111
ReshapePatterns.cpp
1212
RewriteAsConstant.cpp
1313
SwapExtractSliceWithProducerPatterns.cpp

mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp renamed to mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,36 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
2121
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
2222
}
2323

24+
/// Packing one-dimensional tensor can be expressed as an expand shape op.
25+
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
26+
using OpRewritePattern<PackOp>::OpRewritePattern;
27+
28+
Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
29+
Type newOperandType, ArrayAttr reassociation) const {
30+
if (operand.getType() == newOperandType)
31+
return operand;
32+
return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
33+
reassociation);
34+
}
35+
36+
LogicalResult matchAndRewrite(PackOp packOp,
37+
PatternRewriter &rewriter) const override {
38+
RankedTensorType sourceType = packOp.getSourceType();
39+
RankedTensorType destType = packOp.getDestType();
40+
if (sourceType.getRank() != 1 || packOp.getPaddingValue())
41+
return failure();
42+
auto reassociation =
43+
getReassociationIndicesForReshape(sourceType, destType);
44+
if (!reassociation)
45+
return failure();
46+
Value expanded = insertExpand(
47+
rewriter, packOp.getLoc(), packOp.getSource(), destType,
48+
getReassociationIndicesAttribute(rewriter, *reassociation));
49+
rewriter.replaceOp(packOp, expanded);
50+
return success();
51+
}
52+
};
53+
2454
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
2555
/// the pad op has zero low paddings, or if `pack` has no padding values.
2656
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -150,5 +180,9 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
150180
patterns.getContext());
151181
}
152182

183+
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
184+
patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
185+
}
186+
153187
} // namespace tensor
154188
} // namespace mlir

mlir/test/Dialect/Tensor/simplify-tensor-pack.mlir renamed to mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-patterns" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-simplify-pack-unpack-patterns" %s | FileCheck %s
22

33
// CHECK: func.func @single_dim_packing(
44
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ struct TestTensorTransforms
8484
"the extract_slice of collapse_shape pattern"),
8585
llvm::cl::init(false)};
8686

87-
Option<bool> testSimplifyPackPatterns{
88-
*this, "test-simplify-pack-patterns",
89-
llvm::cl::desc("Test patterns to simplify tensor.pack"),
87+
Option<bool> testSimplifyPackUnpackPatterns{
88+
*this, "test-simplify-pack-unpack-patterns",
89+
llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
9090
llvm::cl::init(false)};
9191

9292
Option<bool> testTrackingListener{
@@ -137,9 +137,9 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
137137
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
138138
}
139139

140-
static void applySimplifyPackPatterns(Operation *rootOp) {
140+
static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
141141
RewritePatternSet patterns(rootOp->getContext());
142-
tensor::populateSimplifyTensorPack(patterns);
142+
tensor::populateSimplifyPackAndUnpackPatterns(patterns);
143143
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
144144
}
145145

@@ -376,8 +376,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
376376

377377
void TestTensorTransforms::runOnOperation() {
378378
Operation *rootOp = getOperation();
379-
if (testSimplifyPackPatterns)
380-
applySimplifyPackPatterns(rootOp);
379+
if (testSimplifyPackUnpackPatterns)
380+
applySimplifyPackUnpackPatterns(rootOp);
381381
if (testFoldConstantExtractSlice)
382382
applyFoldConstantExtractSlicePatterns(rootOp);
383383
if (testFoldConsecutiveInsertExtractSlice)

0 commit comments

Comments
 (0)