8
8
9
9
#include " mlir/Dialect/Mesh/Transforms/Simplifications.h"
10
10
#include " mlir/Dialect/Arith/IR/Arith.h"
11
+ #include " mlir/Dialect/Mesh/IR/MeshOps.h"
12
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
13
+ #include " mlir/IR/ImplicitLocOpBuilder.h"
14
+ #include " mlir/IR/PatternMatch.h"
15
+ #include " mlir/IR/SymbolTable.h"
16
+ #include " mlir/Support/LogicalResult.h"
17
+ #include " llvm/ADT/STLExtras.h"
18
+ #include " llvm/ADT/SmallVector.h"
19
+ #include < iterator>
20
+ #include < numeric>
21
+ #include < utility>
11
22
12
23
namespace mlir {
13
24
namespace mesh {
14
25
15
- void populateSimplificationPatterns (RewritePatternSet &patterns) {
26
+ void populateSimplificationPatterns (
27
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
16
28
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
17
29
patterns, Partial::Sum);
18
30
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
@@ -33,6 +45,85 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
33
45
patterns, Partial::Max);
34
46
35
47
// TODO: add simplifications for all-gather and other collectives.
48
+
49
+ populateFoldingPatterns (patterns, symbolTableCollection);
50
+ }
51
+
52
+ namespace {
53
+
54
+ // This folding can not be done with an operation's fold method or
55
+ // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
56
+ // symbol tables.
57
+ // We can't use DialectFoldInterface since the cache may be invalidated by some
58
+ // pass changing the referenced ClusterOp ops.
59
+ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
60
+ template <typename ... OpRewritePatternArgs>
61
+ ClusterShapeFolder (SymbolTableCollection &symbolTableCollection,
62
+ OpRewritePatternArgs &&...opRewritePatternArgs)
63
+ : OpRewritePattern(
64
+ std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
65
+ symbolTableCollection (symbolTableCollection) {}
66
+ LogicalResult matchAndRewrite (ClusterShapeOp op,
67
+ PatternRewriter &rewriter) const override {
68
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
69
+ ClusterOp mesh =
70
+ symbolTableCollection.lookupNearestSymbolFrom <mesh::ClusterOp>(
71
+ op.getOperation (), op.getMeshAttr ());
72
+ if (!mesh) {
73
+ return failure ();
74
+ }
75
+ ArrayRef<MeshAxis> opMeshAxes = op.getAxes ();
76
+ SmallVector<MeshAxis> opAxesIota;
77
+ if (opMeshAxes.empty ()) {
78
+ opAxesIota.resize (mesh.getRank ());
79
+ std::iota (opAxesIota.begin (), opAxesIota.end (), 0 );
80
+ opMeshAxes = opAxesIota;
81
+ }
82
+ if (llvm::all_of (opMeshAxes, [&mesh](MeshAxis axis) {
83
+ return ShapedType::isDynamic (mesh.getDimSizes ()[axis]);
84
+ })) {
85
+ // All mesh dimensions are dynamic. Nothing to fold.
86
+ return failure ();
87
+ }
88
+
89
+ SmallVector<Value> newResults (op->getResults ().size ());
90
+ SmallVector<MeshAxis> newShapeOpMeshAxes;
91
+ SmallVector<size_t > newToOldResultsIndexMap;
92
+
93
+ for (size_t i = 0 ; i < opMeshAxes.size (); ++i) {
94
+ auto meshAxisSize = mesh.getDimSizes ()[opMeshAxes[i]];
95
+ if (ShapedType::isDynamic (meshAxisSize)) {
96
+ newToOldResultsIndexMap.push_back (i);
97
+ newShapeOpMeshAxes.push_back (opMeshAxes[i]);
98
+ } else {
99
+ // Fold static mesh axes.
100
+ newResults[i] = builder.create <arith::ConstantOp>(
101
+ builder.getIndexAttr (meshAxisSize));
102
+ }
103
+ }
104
+
105
+ // Leave only the dynamic mesh axes to be queried.
106
+ ClusterShapeOp newShapeOp =
107
+ builder.create <ClusterShapeOp>(mesh.getSymName (), newShapeOpMeshAxes);
108
+ for (size_t i = 0 ; i < newShapeOp->getResults ().size (); ++i) {
109
+ newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults ()[i];
110
+ }
111
+
112
+ rewriter.replaceAllUsesWith (op.getResults (), newResults);
113
+
114
+ return success ();
115
+ }
116
+
117
+ private:
118
+ SymbolTableCollection &symbolTableCollection;
119
+ };
120
+
121
+ } // namespace
122
+
123
+ void populateFoldingPatterns (RewritePatternSet &patterns,
124
+ SymbolTableCollection &symbolTableCollection) {
125
+ patterns.add <ClusterShapeFolder>(symbolTableCollection,
126
+ patterns.getContext ());
36
127
}
37
128
38
129
} // namespace mesh
0 commit comments