Skip to content

Commit 7eef40d

Browse files
committed
fixup! fixup! [mlir][vector] Group tests for re-order patterns
Group all re-order patterns under `populateSinkVectorOpsPatterns`. Rename test flag and test pass accordingly.
1 parent 90797f1 commit 7eef40d

File tree

5 files changed

+30
-29
lines changed

5 files changed

+30
-29
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,22 @@ void populateVectorTransferFullPartialPatterns(
144144
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
145145
RewritePatternSet &patterns, PatternBenefit benefit = 1);
146146

147-
/// Patterns that remove redundant vector broadcasts.
148-
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
149-
PatternBenefit benefit = 1);
150-
151-
/// Patterns that re-order transpose Ops.
152-
void populateReoderVectorTransposePatterns(RewritePatternSet &patterns,
153-
PatternBenefit benefit = 1);
147+
/// Patterns that remove redundant Vector Ops by re-ordering them with
148+
/// e.g. elementwise Ops:
149+
/// ```
150+
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
151+
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
152+
/// %r = arith.addf %at, %bt : vector<2x4xf32>
153+
/// ```
154+
/// gets converted to:
155+
/// ```
156+
/// %0 = arith.addf %a, %b : vector<4x2xf32>
157+
/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
158+
/// ```
159+
/// At the moment, these patterns are limited to vector.broadcast and
160+
/// vector.transpose.
161+
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
162+
PatternBenefit benefit = 1);
154163

155164
/// Patterns that fold chained vector reductions. These patterns assume that
156165
/// elementwise operations (e.g., `arith.addf` with vector operands) are

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3452,8 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
34523452
if (!getDisableMultiReductionToContractPatterns())
34533453
vector::populateVectorReductionToContractPatterns(patterns);
34543454

3455-
vector::populateReoderVectorTransposePatterns(patterns);
3456-
vector::populateSinkVectorBroadcastPatterns(patterns);
3455+
vector::populateSinkVectorOpsPatterns(patterns);
34573456

34583457
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
34593458
linalg::LinalgCopyVTWForwardingPattern>(ctx,

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,15 +2042,10 @@ void mlir::vector::
20422042
benefit);
20432043
}
20442044

2045-
void mlir::vector::populateSinkVectorBroadcastPatterns(
2046-
RewritePatternSet &patterns, PatternBenefit benefit) {
2047-
patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
2048-
patterns.getContext(), benefit);
2049-
}
2050-
2051-
void mlir::vector::populateReoderVectorTransposePatterns(
2052-
RewritePatternSet &patterns, PatternBenefit benefit) {
2053-
patterns.add<ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
2045+
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2046+
PatternBenefit benefit) {
2047+
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2048+
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
20542049
benefit);
20552050
}
20562051

mlir/test/Dialect/Vector/vector-reorder.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-vector-reorder-patterns -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
22

33
//-----------------------------------------------------------------------------
44
// [Pattern: ReorderElementwiseOpsOnBroadcast]

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -374,19 +374,18 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
374374
}
375375
};
376376

377-
struct TestVectorReorderPatterns
378-
: public PassWrapper<TestVectorReorderPatterns,
379-
OperationPass<func::FuncOp>> {
380-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorReorderPatterns)
377+
struct TestVectorSinkPatterns
378+
: public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
379+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
381380

382-
TestVectorReorderPatterns() = default;
383-
TestVectorReorderPatterns(const TestVectorReorderPatterns &pass) = default;
381+
TestVectorSinkPatterns() = default;
382+
TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
384383

385384
void getDependentDialects(DialectRegistry &registry) const override {
386385
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
387386
}
388387

389-
StringRef getArgument() const final { return "test-vector-reorder-patterns"; }
388+
StringRef getArgument() const final { return "test-vector-sink-patterns"; }
390389

391390
StringRef getDescription() const final {
392391
return "Test lowering patterns that eliminate redundant brodacast "
@@ -395,8 +394,7 @@ struct TestVectorReorderPatterns
395394

396395
void runOnOperation() override {
397396
RewritePatternSet patterns(&getContext());
398-
populateSinkVectorBroadcastPatterns(patterns);
399-
populateReoderVectorTransposePatterns(patterns);
397+
populateSinkVectorOpsPatterns(patterns);
400398
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
401399
}
402400
};
@@ -921,7 +919,7 @@ void registerTestVectorLowerings() {
921919

922920
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
923921

924-
PassRegistration<TestVectorReorderPatterns>();
922+
PassRegistration<TestVectorSinkPatterns>();
925923

926924
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
927925

0 commit comments

Comments
 (0)