Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 3e3e276

Browse files
committed
[mlir][vector][NFC] Change UnrollVectorPattern to not be statically dependent on an op type
Make UnrollVectorPattern inherit from RewritePattern instead of OpRewritePattern so that we don't need to create many patterns when applying to many different type of ops. Since we may want to apply the pattern to all arithmetic op, it is more convenient to filter dynamically. Differential Revision: https://reviews.llvm.org/D92635
1 parent 840e651 commit 3e3e276

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

mlir/include/mlir/Dialect/Vector/VectorTransforms.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct UnrollVectorOptions {
9191
/// Callback function that indicates whether vector unrolling should be
9292
/// attempted on the operation.
9393
FilterConstraintFnType filterConstraint = nullptr;
94-
UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
94+
UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
9595
filterConstraint = constraint;
9696
return *this;
9797
}
@@ -117,34 +117,32 @@ struct UnrollVectorOptions {
117117
};
118118
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
119119
/// declaratively.
120-
template <typename OpTy>
121-
struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
122-
using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
120+
struct UnrollVectorPattern : public RewritePattern {
121+
using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
123122
UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
124-
: OpRewritePattern<OpTy>(context), options(options) {}
125-
LogicalResult matchAndRewrite(OpTy op,
123+
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
124+
LogicalResult matchAndRewrite(Operation *op,
126125
PatternRewriter &rewriter) const override {
127126
if (options.filterConstraint && failed(options.filterConstraint(op)))
128127
return failure();
129128
if (!options.nativeShape) {
130-
return op.emitError("vector unrolling expects the native shape or native"
131-
"shape call back function to be set");
129+
return op->emitError("vector unrolling expects the native shape or native"
130+
"shape call back function to be set");
132131
}
133-
auto unrollableVectorOp =
134-
dyn_cast<VectorUnrollOpInterface>(op.getOperation());
132+
auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
135133
if (!unrollableVectorOp)
136134
return failure();
137135
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
138136
if (!maybeUnrollShape)
139137
return failure();
140138
Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
141139
if (!targetShape)
142-
return op.emitError("failed to get target shape for vector unroll");
140+
return op->emitError("failed to get target shape for vector unroll");
143141
auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
144142
if (!maybeShapeRatio ||
145143
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
146144
return failure();
147-
if (std::is_same<OpTy, TransferWriteOp>::value) {
145+
if (isa<TransferWriteOp>(op)) {
148146
if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
149147
return failure();
150148
rewriter.eraseOp(op);

mlir/test/lib/Transforms/TestVectorTransforms.cpp

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,22 @@ struct TestVectorToVectorConversion
2727
void runOnFunction() override {
2828
OwningRewritePatternList patterns;
2929
auto *ctx = &getContext();
30-
patterns.insert<UnrollVectorPattern<AddFOp>>(
31-
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
32-
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
33-
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
30+
patterns.insert<UnrollVectorPattern>(
31+
ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
3432
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
3533
populateVectorToVectorTransformationPatterns(patterns, ctx);
3634
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
3735
}
36+
37+
private:
38+
// Return the target shape based on op type.
39+
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
40+
if (isa<AddFOp>(op))
41+
return SmallVector<int64_t, 4>(2, 2);
42+
if (isa<vector::ContractionOp>(op))
43+
return SmallVector<int64_t, 4>(3, 2);
44+
return llvm::None;
45+
}
3846
};
3947

4048
struct TestVectorSlicesConversion
@@ -120,8 +128,11 @@ struct TestVectorUnrollingPatterns
120128
void runOnFunction() override {
121129
MLIRContext *ctx = &getContext();
122130
OwningRewritePatternList patterns;
123-
patterns.insert<UnrollVectorPattern<AddFOp>>(
124-
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
131+
patterns.insert<UnrollVectorPattern>(
132+
ctx, UnrollVectorOptions()
133+
.setNativeShape(ArrayRef<int64_t>{2, 2})
134+
.setFilterConstraint(
135+
[](Operation *op) { return success(isa<AddFOp>(op)); }));
125136

126137
if (unrollBasedOnType) {
127138
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
@@ -137,12 +148,19 @@ struct TestVectorUnrollingPatterns
137148
}
138149
return nativeShape;
139150
};
140-
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
141-
ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
151+
patterns.insert<UnrollVectorPattern>(
152+
ctx, UnrollVectorOptions()
153+
.setNativeShapeFn(nativeShapeFn)
154+
.setFilterConstraint([](Operation *op) {
155+
return success(isa<ContractionOp>(op));
156+
}));
142157
} else {
143-
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
144-
ctx,
145-
UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
158+
patterns.insert<UnrollVectorPattern>(
159+
ctx, UnrollVectorOptions()
160+
.setNativeShape(ArrayRef<int64_t>{2, 2, 2})
161+
.setFilterConstraint([](Operation *op) {
162+
return success(isa<ContractionOp>(op));
163+
}));
146164
}
147165
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
148166
populateVectorToVectorTransformationPatterns(patterns, ctx);
@@ -273,10 +291,14 @@ struct TestVectorTransferUnrollingPatterns
273291
void runOnFunction() override {
274292
MLIRContext *ctx = &getContext();
275293
OwningRewritePatternList patterns;
276-
patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
277-
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
278-
patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
279-
ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
294+
patterns.insert<UnrollVectorPattern>(
295+
ctx,
296+
UnrollVectorOptions()
297+
.setNativeShape(ArrayRef<int64_t>{2, 2})
298+
.setFilterConstraint([](Operation *op) {
299+
return success(
300+
isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
301+
}));
280302
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
281303
populateVectorToVectorTransformationPatterns(patterns, ctx);
282304
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

0 commit comments

Comments
 (0)