Skip to content

Commit 5279e11

Browse files
committed
[mlir][linalg] Retire Linalg's Vectorization Pattern
This revision retires the LinalgCodegenStrategy vectorization pattern. Please see the context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785. This revision improves the transform dialect's VectorizeOp in different ways below: - Adds LinalgDialect as a dependent dialect. When `transform.structured.vectorize` vectorizes `tensor.pad`, it generates `linalg.init_tensor`. In this case, linalg dialect must be registered. - Inserts CopyVectorizationPattern in order to vectorize `memref.copy`. - Creates two attributes: `disable_multi_reduction_to_contract_patterns` and `disable_transfer_permutation_map_lowering_patterns`. They are limiting the power of vectorization and are currently intended for testing purposes. It also removes some of the "CHECK: vector.transfer_write" in the vectorization.mlir test. They are redundant writes, at the end of the code there is a rewrite to the same place. Transform dialect no longer generates them. Depends on D133684 that retires the LinalgCodegenStrategy vectorization pass. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D133699
1 parent 51e0946 commit 5279e11

File tree

6 files changed

+499
-77
lines changed

6 files changed

+499
-77
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
767767
Note that this transformation is invalidating the handles to any payload IR
768768
operation that is contained inside the vectorization target.
769769

770+
`disable_multi_reduction_to_contract_patterns` and
771+
`disable_transfer_permutation_map_lowering_patterns` limits the power of
772+
vectorization. They are currently intended for testing purposes.
773+
770774
#### Return modes:
771775

772776
This operation produces `definiteFailure` if vectorization fails for any
@@ -776,7 +780,9 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
776780
}];
777781

778782
let arguments = (ins PDL_Operation:$target,
779-
DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding);
783+
DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,
784+
DefaultValuedAttr<BoolAttr, "false">:$disable_multi_reduction_to_contract_patterns,
785+
DefaultValuedAttr<BoolAttr, "false">:$disable_transfer_permutation_map_lowering_patterns);
780786
let results = (outs PDL_Operation:$transformed);
781787

782788
let assemblyFormat = "$target attr-dict";

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -926,31 +926,6 @@ struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
926926
/// Empty for now, used for SFINAE purposes only.
927927
struct LinalgVectorizationOptions {};
928928

929-
/// `filter` controls LinalgTransformMarker matching and update when specified.
930-
/// See `vectorizeLinalgOp` for more details.
931-
struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
932-
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
933-
LinalgVectorizationPattern(
934-
MLIRContext *context,
935-
LinalgTransformationFilter f = LinalgTransformationFilter(),
936-
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
937-
PatternBenefit benefit = 1);
938-
939-
/// Construct a pattern specifically applied to `opName`.
940-
LinalgVectorizationPattern(
941-
StringRef opName, MLIRContext *context,
942-
LinalgVectorizationOptions options = LinalgVectorizationOptions(),
943-
LinalgTransformationFilter f = LinalgTransformationFilter(),
944-
PatternBenefit benefit = 1);
945-
946-
LogicalResult matchAndRewrite(LinalgOp linalgOp,
947-
PatternRewriter &rewriter) const override;
948-
949-
private:
950-
/// LinalgTransformMarker handles special attribute manipulations.
951-
LinalgTransformationFilter filter;
952-
};
953-
954929
/// `filter` controls LinalgTransformMarker matching and update when specified.
955930
/// See `vectorizeLinalgOp` for more details.
956931
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
@@ -1335,18 +1310,6 @@ class VectorizationPatterns<> {
13351310
const LinalgTransformationFilter &f) {}
13361311
};
13371312

1338-
template <typename OpTy, typename... OpTypes>
1339-
class VectorizationPatterns<OpTy, OpTypes...> {
1340-
public:
1341-
static void insert(RewritePatternSet &patterns,
1342-
const LinalgVectorizationOptions &options,
1343-
const LinalgTransformationFilter &f) {
1344-
patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
1345-
patterns.getContext(), options, f);
1346-
VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
1347-
}
1348-
};
1349-
13501313
template <typename... OpTypes>
13511314
class TilingPatterns;
13521315

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,22 @@ LogicalResult TileToForeachThreadOp::verify() {
11661166
// VectorizeOp
11671167
//===----------------------------------------------------------------------===//
11681168

1169+
namespace {
1170+
/// This is an helper only to call vectorize via a pattern inside of
1171+
/// VectorizeOp::applyToOne.
1172+
struct VectorizationPattern : public RewritePattern {
1173+
explicit VectorizationPattern(MLIRContext *context)
1174+
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1175+
LogicalResult matchAndRewrite(Operation *op,
1176+
PatternRewriter &rewriter) const override {
1177+
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
1178+
if (!linalgOp)
1179+
return failure();
1180+
return vectorize(rewriter, linalgOp);
1181+
}
1182+
};
1183+
} // namespace
1184+
11691185
DiagnosedSilenceableFailure
11701186
transform::VectorizeOp::applyToOne(Operation *target,
11711187
SmallVectorImpl<Operation *> &results,
@@ -1178,15 +1194,22 @@ transform::VectorizeOp::applyToOne(Operation *target,
11781194

11791195
MLIRContext *ctx = getContext();
11801196
RewritePatternSet patterns(ctx);
1181-
patterns.add<LinalgVectorizationPattern>(ctx);
1197+
patterns.add<VectorizationPattern>(ctx);
1198+
1199+
if (!getDisableTransferPermutationMapLoweringPatterns())
1200+
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1201+
1202+
if (!getDisableMultiReductionToContractPatterns())
1203+
vector::populateVectorReductionToContractPatterns(patterns);
11821204

1183-
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1184-
vector::populateVectorReductionToContractPatterns(patterns);
11851205
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
11861206
linalg::LinalgCopyVTWForwardingPattern>(ctx,
11871207
/*benefit=*/2);
11881208
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
11891209
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1210+
1211+
patterns.add<CopyVectorizationPattern>(ctx);
1212+
11901213
if (getVectorizePadding())
11911214
linalg::populatePadOpVectorizationPatterns(patterns);
11921215

@@ -1212,7 +1235,7 @@ class LinalgTransformDialectExtension
12121235

12131236
void init() {
12141237
declareDependentDialect<pdl::PDLDialect>();
1215-
1238+
declareDependentDialect<LinalgDialect>();
12161239
declareGeneratedDialect<AffineDialect>();
12171240
declareGeneratedDialect<arith::ArithmeticDialect>();
12181241
declareGeneratedDialect<scf::SCFDialect>();

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -590,25 +590,6 @@ LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
590590
return success();
591591
}
592592

593-
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
594-
MLIRContext *context, LinalgTransformationFilter f,
595-
LinalgVectorizationOptions options, PatternBenefit benefit)
596-
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
597-
filter(std::move(f)) {}
598-
599-
mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
600-
StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
601-
LinalgTransformationFilter f, PatternBenefit benefit)
602-
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
603-
filter(f.addOpNameFilter(opName)) {}
604-
605-
LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
606-
LinalgOp linalgOp, PatternRewriter &rewriter) const {
607-
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
608-
return failure();
609-
return vectorize(rewriter, linalgOp);
610-
}
611-
612593
LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
613594
memref::CopyOp copyOp, PatternRewriter &rewriter) const {
614595
return vectorizeCopy(rewriter, copyOp);

0 commit comments

Comments
 (0)