Skip to content

Commit aad2791

Browse files
committed
Add folding patterns to all simplification patterns
1 parent 3706d6f commit aad2791

File tree

7 files changed

+25
-70
lines changed

7 files changed

+25
-70
lines changed

mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
namespace mlir {
2222

23-
class SymbolTable;
23+
class SymbolTableCollection;
2424

2525
namespace mesh {
2626

@@ -105,11 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
105105
AlgebraicOp::getOperationName(), 1, patterns.getContext()));
106106
}
107107

108-
void populateSimplificationPatterns(RewritePatternSet &patterns);
109108
// It is invalid to change ops that declare symbols during the application of
110-
// these patterns, because symbolTable is used to cache them.
109+
// these patterns, because symbolTableCollection is used to cache them.
110+
void populateSimplificationPatterns(
111+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
111112
void populateFoldingPatterns(RewritePatternSet &patterns,
112-
SymbolTableCollection &symbolTable);
113+
SymbolTableCollection &symbolTableCollection);
113114

114115
} // namespace mesh
115116
} // namespace mlir

mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
namespace mlir {
2424
namespace mesh {
2525

26-
void populateSimplificationPatterns(RewritePatternSet &patterns) {
26+
void populateSimplificationPatterns(
27+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
2728
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
2829
patterns, Partial::Sum);
2930
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
@@ -44,6 +45,8 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
4445
patterns, Partial::Max);
4546

4647
// TODO: add simplifications for all-gather and other collectives.
48+
49+
populateFoldingPatterns(patterns, symbolTableCollection);
4750
}
4851

4952
namespace {
@@ -55,16 +58,17 @@ namespace {
5558
// pass changing the referenced ClusterOp ops.
5659
struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
5760
template <typename... OpRewritePatternArgs>
58-
ClusterShapeFolder(SymbolTableCollection &symbolTable,
61+
ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
5962
OpRewritePatternArgs &&...opRewritePatternArgs)
6063
: OpRewritePattern(
6164
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
62-
symbolTable(symbolTable) {}
65+
symbolTableCollection(symbolTableCollection) {}
6366
LogicalResult matchAndRewrite(ClusterShapeOp op,
6467
PatternRewriter &rewriter) const override {
6568
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
66-
ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
67-
op.getOperation(), op.getMeshAttr());
69+
ClusterOp mesh =
70+
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
71+
op.getOperation(), op.getMeshAttr());
6872
if (!mesh) {
6973
return failure();
7074
}
@@ -111,14 +115,15 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
111115
}
112116

113117
private:
114-
SymbolTableCollection &symbolTable;
118+
SymbolTableCollection &symbolTableCollection;
115119
};
116120

117121
} // namespace
118122

119123
void populateFoldingPatterns(RewritePatternSet &patterns,
120-
SymbolTableCollection &symbolTable) {
121-
patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
124+
SymbolTableCollection &symbolTableCollection) {
125+
patterns.add<ClusterShapeFolder>(symbolTableCollection,
126+
patterns.getContext());
122127
}
123128

124129
} // namespace mesh

mlir/test/Dialect/Mesh/folding.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
1+
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
22

33
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
44
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)

mlir/test/lib/Dialect/Mesh/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRMeshTest
3-
TestFolding.cpp
43
TestReshardingSpmdization.cpp
54
TestSimplifications.cpp
65

mlir/test/lib/Dialect/Mesh/TestFolding.cpp

Lines changed: 0 additions & 52 deletions
This file was deleted.

mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1111
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
12+
#include "mlir/IR/SymbolTable.h"
1213
#include "mlir/Pass/Pass.h"
1314
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1415

@@ -30,8 +31,11 @@ struct TestMeshSimplificationsPass
3031

3132
void TestMeshSimplificationsPass::runOnOperation() {
3233
RewritePatternSet patterns(&getContext());
33-
mesh::populateSimplificationPatterns(patterns);
34-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
34+
SymbolTableCollection symbolTableCollection;
35+
mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
36+
LogicalResult status =
37+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
38+
assert(succeeded(status) && "Rewrite patters application did not converge.");
3539
}
3640

3741
namespace mlir {

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ void registerTestMathAlgebraicSimplificationPass();
118118
void registerTestMathPolynomialApproximationPass();
119119
void registerTestMemRefDependenceCheck();
120120
void registerTestMemRefStrideCalculation();
121-
void registerTestMeshFoldingPass();
122121
void registerTestMeshSimplificationsPass();
123122
void registerTestMeshReshardingSpmdizationPass();
124123
void registerTestNextAccessPass();
@@ -241,7 +240,6 @@ void registerTestPasses() {
241240
mlir::test::registerTestMathPolynomialApproximationPass();
242241
mlir::test::registerTestMemRefDependenceCheck();
243242
mlir::test::registerTestMemRefStrideCalculation();
244-
mlir::test::registerTestMeshFoldingPass();
245243
mlir::test::registerTestMeshSimplificationsPass();
246244
mlir::test::registerTestMeshReshardingSpmdizationPass();
247245
mlir::test::registerTestNextAccessPass();

0 commit comments

Comments
 (0)