Skip to content

Commit 9a79b1b

Browse files
committed
[mlir] Add peeling xform to Codegen Strategy
This patch adds the knobs to use peeling in the codegen strategy infrastructure. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D126842
1 parent 5ac2615 commit 9a79b1b

File tree

6 files changed

+182
-0
lines changed

6 files changed

+182
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ createLinalgStrategyInterchangePass(
127127
const linalg::LinalgTransformationFilter &filter =
128128
linalg::LinalgTransformationFilter());
129129

130+
/// Create a LinalgStrategyPeelPass.
131+
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPeelPass(
132+
StringRef opName = "",
133+
linalg::LinalgPeelOptions opt = linalg::LinalgPeelOptions(),
134+
const linalg::LinalgTransformationFilter &filter =
135+
linalg::LinalgTransformationFilter());
136+
130137
/// Create a LinalgStrategyVectorizePass.
131138
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
132139
StringRef opName = "",

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,22 @@ def LinalgStrategyInterchangePass
272272
];
273273
}
274274

275+
def LinalgStrategyPeelPass
276+
: Pass<"linalg-strategy-peel-pass", "func::FuncOp"> {
277+
let summary = "Configurable pass to apply pattern-based linalg peeling.";
278+
let constructor = "mlir::createLinalgStrategyPeelPass()";
279+
let dependentDialects = [
280+
"linalg::LinalgDialect",
281+
"scf::SCFDialect"
282+
];
283+
let options = [
284+
Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
285+
"Which func op is the anchor to latch on.">,
286+
Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
287+
"Which linalg op within the func is the anchor to latch on.">,
288+
];
289+
}
290+
275291
def LinalgStrategyVectorizePass
276292
: Pass<"linalg-strategy-vectorize-pass", "func::FuncOp"> {
277293
let summary = "Configurable pass to apply pattern-based linalg vectorization.";

mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,26 @@ struct Decompose : public Transformation {
141141
}
142142
};
143143

144+
/// Represent one application of createLinalgStrategyPeelPass.
145+
struct Peel : public Transformation {
146+
explicit Peel(linalg::LinalgPeelOptions options,
147+
LinalgTransformationFilter::FilterFunction f = nullptr)
148+
: Transformation(std::move(f)), opName(), options(options) {}
149+
150+
Peel(StringRef name, linalg::LinalgPeelOptions options,
151+
LinalgTransformationFilter::FilterFunction f = nullptr)
152+
: Transformation(std::move(f)), opName(name), options(options) {}
153+
154+
void addToPassPipeline(OpPassManager &pm,
155+
LinalgTransformationFilter m) const override {
156+
pm.addPass(createLinalgStrategyPeelPass(opName, options, m));
157+
}
158+
159+
private:
160+
std::string opName;
161+
linalg::LinalgPeelOptions options;
162+
};
163+
144164
/// Represent one application of createLinalgStrategyVectorizePass.
145165
struct Vectorize : public Transformation {
146166
explicit Vectorize(linalg::LinalgVectorizationOptions options,
@@ -288,6 +308,20 @@ struct CodegenStrategy {
288308
decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
289309
return b ? decompose(std::move(f)) : *this;
290310
}
311+
/// Append a pattern to peel 'LinalgOpType'.
312+
CodegenStrategy &
313+
peel(StringRef opName, const LinalgPeelOptions &options,
314+
const LinalgTransformationFilter::FilterFunction &f = nullptr) {
315+
transformationSequence.emplace_back(
316+
std::make_unique<Peel>(opName, options, f));
317+
return *this;
318+
}
319+
/// Conditionally append a pattern to peel 'LinalgOpType'.
320+
CodegenStrategy &
321+
peelIf(bool b, StringRef opName, const LinalgPeelOptions &options,
322+
LinalgTransformationFilter::FilterFunction f = nullptr) {
323+
return b ? peel(opName, options, std::move(f)) : *this;
324+
}
291325
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
292326
CodegenStrategy &
293327
vectorize(StringRef opName,

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ struct TiledLinalgOp {
129129
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
130130
const LinalgTilingOptions &options);
131131

132+
/// Peel and canonicalize 'loops'.
133+
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
134+
132135
/// Peel the loops of a TiledLinalgOp.
133136
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
134137
ArrayRef<int64_t> peeledLoops,
@@ -965,6 +968,49 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
965968
: LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
966969
};
967970

971+
///
972+
/// Linalg peeling patterns.
973+
///
974+
975+
/// Compute the loops to peel and return them in a SmallVector. Loops will be
976+
/// peeled in order of appearance in the SmallVector. This order will impact the
977+
/// output IR. If an inner-to-outer order is provided, the peeled iterations of
978+
/// the outer loops will also contain the peeled inner loops. If an
979+
/// outer-to-inner order is provided, the peeled iterations of the outer loops
980+
/// will not contain any peeled inner loops.
981+
using LoopsToPeelComputationFunction = std::function<void(
982+
OpBuilder &, Operation *, SmallVectorImpl<scf::ForOp> &)>;
983+
984+
struct LinalgPeelOptions {
985+
LoopsToPeelComputationFunction loopsToPeelComputationFunction = nullptr;
986+
};
987+
988+
/// `filter` controls LinalgTransformMarker matching and update when specified.
989+
struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
990+
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
991+
LinalgPeelingPattern(
992+
MLIRContext *context,
993+
LinalgTransformationFilter f = LinalgTransformationFilter(),
994+
LinalgPeelOptions options = LinalgPeelOptions(),
995+
PatternBenefit benefit = 1);
996+
997+
/// Construct a pattern specifically applied to `opName`.
998+
LinalgPeelingPattern(
999+
StringRef opName, MLIRContext *context,
1000+
LinalgPeelOptions options = LinalgPeelOptions(),
1001+
LinalgTransformationFilter f = LinalgTransformationFilter(),
1002+
PatternBenefit benefit = 1);
1003+
1004+
LogicalResult matchAndRewrite(LinalgOp linalgOp,
1005+
PatternRewriter &rewriter) const override;
1006+
1007+
private:
1008+
/// LinalgTransformMarker handles special attribute manipulations.
1009+
const LinalgTransformationFilter filter;
1010+
/// Peeling options.
1011+
const LinalgPeelOptions options;
1012+
};
1013+
9681014
///
9691015
/// Linalg vectorization patterns.
9701016
///

mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,40 @@ struct LinalgStrategyPromotePass
262262
LinalgTransformationFilter filter;
263263
};
264264

265+
/// Configurable pass to apply pattern-based linalg peeling.
266+
struct LinalgStrategyPeelPass
267+
: public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> {
268+
269+
LinalgStrategyPeelPass() = default;
270+
271+
LinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt,
272+
LinalgTransformationFilter filt)
273+
: options(opt), filter(std::move(filt)) {
274+
this->anchorOpName.setValue(opName.str());
275+
}
276+
277+
void runOnOperation() override {
278+
auto funcOp = getOperation();
279+
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
280+
return;
281+
282+
RewritePatternSet peelingPatterns(funcOp.getContext());
283+
if (!anchorOpName.empty()) {
284+
peelingPatterns.add<LinalgPeelingPattern>(
285+
anchorOpName, funcOp.getContext(), options, filter);
286+
} else {
287+
peelingPatterns.add<LinalgPeelingPattern>(funcOp.getContext(), filter,
288+
options);
289+
}
290+
if (failed(
291+
applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns))))
292+
return signalPassFailure();
293+
}
294+
295+
LinalgPeelOptions options;
296+
LinalgTransformationFilter filter;
297+
};
298+
265299
/// Configurable pass to apply pattern-based linalg vectorization.
266300
struct LinalgStrategyVectorizePass
267301
: public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
@@ -506,6 +540,13 @@ mlir::createLinalgStrategyInterchangePass(
506540
filter);
507541
}
508542

543+
/// Create a LinalgStrategyPeelPass.
544+
std::unique_ptr<OperationPass<func::FuncOp>>
545+
mlir::createLinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt,
546+
const LinalgTransformationFilter &filter) {
547+
return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter);
548+
}
549+
509550
/// Create a LinalgStrategyVectorizePass.
510551
std::unique_ptr<OperationPass<func::FuncOp>>
511552
mlir::createLinalgStrategyVectorizePass(

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,15 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
323323
.Default([&](Operation *op) { return op->getResults(); });
324324
}
325325

326+
/// Peel and canonicalize 'loops'.
327+
void mlir::linalg::peelLoops(RewriterBase &rewriter,
328+
ArrayRef<scf::ForOp> loops) {
329+
for (auto loopOp : loops) {
330+
SmallVector<Value, 4> loopResults;
331+
loopResults = peelLoop(rewriter, loopOp);
332+
}
333+
}
334+
326335
/// Peel loops after tiling.
327336
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
328337
ArrayRef<int64_t> peeledLoops,
@@ -716,6 +725,35 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
716725
return success();
717726
}
718727

728+
mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
729+
MLIRContext *context, LinalgTransformationFilter f,
730+
LinalgPeelOptions options, PatternBenefit benefit)
731+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
732+
filter(std::move(f)), options(std::move(options)) {}
733+
734+
mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
735+
StringRef opName, MLIRContext *context, LinalgPeelOptions options,
736+
LinalgTransformationFilter f, PatternBenefit benefit)
737+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
738+
filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
739+
740+
LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
741+
LinalgOp linalgOp, PatternRewriter &rewriter) const {
742+
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
743+
return failure();
744+
745+
// Increase marker counter even if peeling doesn't happen for this op.
746+
filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
747+
748+
if (!options.loopsToPeelComputationFunction)
749+
return failure();
750+
751+
SmallVector<scf::ForOp, 4> loopsToPeel;
752+
options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
753+
peelLoops(rewriter, loopsToPeel);
754+
return success();
755+
}
756+
719757
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
720758
MLIRContext *context, LinalgTransformationFilter f,
721759
LinalgVectorizationOptions options, PatternBenefit benefit)

0 commit comments

Comments
 (0)