Skip to content

Commit ab59037

Browse files
authored
[mlir][mesh] Add folding of ClusterShapeOp (llvm#77033)
If the mesh has a static size on some of the requested axes, the corresponding results are replaced with constants.
1 parent cd101ab commit ab59037

File tree

6 files changed

+131
-6
lines changed

6 files changed

+131
-6
lines changed

mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include <utility>
2020

2121
namespace mlir {
22+
23+
class SymbolTableCollection;
24+
2225
namespace mesh {
2326

2427
// If we have an algebraic op like "+" and a summing all-reduce,
@@ -102,7 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
102105
AlgebraicOp::getOperationName(), 1, patterns.getContext()));
103106
}
104107

105-
void populateSimplificationPatterns(RewritePatternSet &patterns);
108+
// It is invalid to change ops that declare symbols during the application of
109+
// these patterns, because symbolTableCollection is used to cache them.
110+
void populateSimplificationPatterns(
111+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
112+
void populateFoldingPatterns(RewritePatternSet &patterns,
113+
SymbolTableCollection &symbolTableCollection);
106114

107115
} // namespace mesh
108116
} // namespace mlir

mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,23 @@
88

99
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
12+
#include "mlir/IR/BuiltinTypeInterfaces.h"
13+
#include "mlir/IR/ImplicitLocOpBuilder.h"
14+
#include "mlir/IR/PatternMatch.h"
15+
#include "mlir/IR/SymbolTable.h"
16+
#include "mlir/Support/LogicalResult.h"
17+
#include "llvm/ADT/STLExtras.h"
18+
#include "llvm/ADT/SmallVector.h"
19+
#include <iterator>
20+
#include <numeric>
21+
#include <utility>
1122

1223
namespace mlir {
1324
namespace mesh {
1425

15-
void populateSimplificationPatterns(RewritePatternSet &patterns) {
26+
void populateSimplificationPatterns(
27+
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
1628
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
1729
patterns, Partial::Sum);
1830
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
@@ -33,6 +45,85 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
3345
patterns, Partial::Max);
3446

3547
// TODO: add simplifications for all-gather and other collectives.
48+
49+
populateFoldingPatterns(patterns, symbolTableCollection);
50+
}
51+
52+
namespace {
53+
54+
// This folding can not be done with an operation's fold method or
55+
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
56+
// symbol tables.
57+
// We can't use DialectFoldInterface since the cache may be invalidated by some
58+
// pass changing the referenced ClusterOp ops.
59+
struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
60+
template <typename... OpRewritePatternArgs>
61+
ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
62+
OpRewritePatternArgs &&...opRewritePatternArgs)
63+
: OpRewritePattern(
64+
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
65+
symbolTableCollection(symbolTableCollection) {}
66+
LogicalResult matchAndRewrite(ClusterShapeOp op,
67+
PatternRewriter &rewriter) const override {
68+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
69+
ClusterOp mesh =
70+
symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
71+
op.getOperation(), op.getMeshAttr());
72+
if (!mesh) {
73+
return failure();
74+
}
75+
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
76+
SmallVector<MeshAxis> opAxesIota;
77+
if (opMeshAxes.empty()) {
78+
opAxesIota.resize(mesh.getRank());
79+
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
80+
opMeshAxes = opAxesIota;
81+
}
82+
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
83+
return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
84+
})) {
85+
// All mesh dimensions are dynamic. Nothing to fold.
86+
return failure();
87+
}
88+
89+
SmallVector<Value> newResults(op->getResults().size());
90+
SmallVector<MeshAxis> newShapeOpMeshAxes;
91+
SmallVector<size_t> newToOldResultsIndexMap;
92+
93+
for (size_t i = 0; i < opMeshAxes.size(); ++i) {
94+
auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
95+
if (ShapedType::isDynamic(meshAxisSize)) {
96+
newToOldResultsIndexMap.push_back(i);
97+
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
98+
} else {
99+
// Fold static mesh axes.
100+
newResults[i] = builder.create<arith::ConstantOp>(
101+
builder.getIndexAttr(meshAxisSize));
102+
}
103+
}
104+
105+
// Leave only the dynamic mesh axes to be queried.
106+
ClusterShapeOp newShapeOp =
107+
builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
108+
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
109+
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
110+
}
111+
112+
rewriter.replaceAllUsesWith(op.getResults(), newResults);
113+
114+
return success();
115+
}
116+
117+
private:
118+
SymbolTableCollection &symbolTableCollection;
119+
};
120+
121+
} // namespace
122+
123+
// Registers the mesh folding patterns in `patterns`. The patterns are
// constructed with `symbolTableCollection` and the context of `patterns`.
void populateFoldingPatterns(RewritePatternSet &patterns,
                             SymbolTableCollection &symbolTableCollection) {
  auto *context = patterns.getContext();
  patterns.add<ClusterShapeFolder>(symbolTableCollection, context);
}
37128

38129
} // namespace mesh

mlir/test/Dialect/Mesh/folding.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
2+
3+
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
4+
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
5+
6+
// CHECK-LABEL: func.func @cluster_shape_op_folding
7+
func.func @cluster_shape_op_folding() -> (index, index) {
8+
// CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
9+
// CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
10+
%0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
11+
// CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
12+
return %0#0, %0#1 : index, index
13+
}
14+
15+
// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
16+
func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
17+
// CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
18+
// CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
19+
%0:2 = mesh.cluster_shape @mesh1 : index, index
20+
// CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
21+
return %0#0, %0#1 : index, index
22+
}

mlir/test/lib/Dialect/Mesh/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Exclude tests from libMLIR.so
2-
add_mlir_library(MLIRMeshTestSimplifications
2+
add_mlir_library(MLIRMeshTest
33
TestReshardingSpmdization.cpp
44
TestSimplifications.cpp
55

mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1111
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
12+
#include "mlir/IR/SymbolTable.h"
1213
#include "mlir/Pass/Pass.h"
1314
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1415

@@ -30,8 +31,11 @@ struct TestMeshSimplificationsPass
3031

3132
void TestMeshSimplificationsPass::runOnOperation() {
3233
RewritePatternSet patterns(&getContext());
33-
mesh::populateSimplificationPatterns(patterns);
34-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
34+
SymbolTableCollection symbolTableCollection;
35+
mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
36+
LogicalResult status =
37+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
38+
assert(succeeded(status) && "Rewrite patters application did not converge.");
3539
}
3640

3741
namespace mlir {

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
2626
MLIRLoopLikeInterfaceTestPasses
2727
MLIRMathTestPasses
2828
MLIRMemRefTestPasses
29-
MLIRMeshTestSimplifications
29+
MLIRMeshTest
3030
MLIRNVGPUTestPasses
3131
MLIRSCFTestPasses
3232
MLIRShapeTestPasses

0 commit comments

Comments
 (0)