Skip to content

Commit 3706d6f

Browse files
committed
[mlir][mesh] Add folding of ClusterShapeOp
If the mesh has static size on some of the requested axes, the result is substituted with a constant.
1 parent 03a0bfa commit 3706d6f

File tree

7 files changed

+172
-2
lines changed

7 files changed

+172
-2
lines changed

mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include <utility>
2020

2121
namespace mlir {
22+
23+
class SymbolTable;
24+
2225
namespace mesh {
2326

2427
// If we have an algebraic op like "+" and a summing all-reduce,
@@ -103,6 +106,10 @@ void populateAllReduceEndomorphismSimplificationPatterns(
103106
}
104107

105108
void populateSimplificationPatterns(RewritePatternSet &patterns);
109+
// It is invalid to change ops that declare symbols during the application of
110+
// these patterns, because symbolTable is used to cache them.
111+
void populateFoldingPatterns(RewritePatternSet &patterns,
112+
SymbolTableCollection &symbolTable);
106113

107114
} // namespace mesh
108115
} // namespace mlir

mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@
88

99
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
12+
#include "mlir/IR/BuiltinTypeInterfaces.h"
13+
#include "mlir/IR/ImplicitLocOpBuilder.h"
14+
#include "mlir/IR/PatternMatch.h"
15+
#include "mlir/IR/SymbolTable.h"
16+
#include "mlir/Support/LogicalResult.h"
17+
#include "llvm/ADT/STLExtras.h"
18+
#include "llvm/ADT/SmallVector.h"
19+
#include <iterator>
20+
#include <numeric>
21+
#include <utility>
1122

1223
namespace mlir {
1324
namespace mesh {
@@ -35,5 +46,80 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
3546
// TODO: add simplifications for all-gather and other collectives.
3647
}
3748

49+
namespace {
50+
51+
// This folding can not be done with an operation's fold method or
52+
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
53+
// symbol tables.
54+
// We can't use DialectFoldInterface since the cache may be invalidated by some
55+
// pass changing the referenced ClusterOp ops.
56+
struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
57+
template <typename... OpRewritePatternArgs>
58+
ClusterShapeFolder(SymbolTableCollection &symbolTable,
59+
OpRewritePatternArgs &&...opRewritePatternArgs)
60+
: OpRewritePattern(
61+
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
62+
symbolTable(symbolTable) {}
63+
LogicalResult matchAndRewrite(ClusterShapeOp op,
64+
PatternRewriter &rewriter) const override {
65+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
66+
ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
67+
op.getOperation(), op.getMeshAttr());
68+
if (!mesh) {
69+
return failure();
70+
}
71+
ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
72+
SmallVector<MeshAxis> opAxesIota;
73+
if (opMeshAxes.empty()) {
74+
opAxesIota.resize(mesh.getRank());
75+
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
76+
opMeshAxes = opAxesIota;
77+
}
78+
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
79+
return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
80+
})) {
81+
// All mesh dimensions are dynamic. Nothing to fold.
82+
return failure();
83+
}
84+
85+
SmallVector<Value> newResults(op->getResults().size());
86+
SmallVector<MeshAxis> newShapeOpMeshAxes;
87+
SmallVector<size_t> newToOldResultsIndexMap;
88+
89+
for (size_t i = 0; i < opMeshAxes.size(); ++i) {
90+
auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
91+
if (ShapedType::isDynamic(meshAxisSize)) {
92+
newToOldResultsIndexMap.push_back(i);
93+
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
94+
} else {
95+
// Fold static mesh axes.
96+
newResults[i] = builder.create<arith::ConstantOp>(
97+
builder.getIndexAttr(meshAxisSize));
98+
}
99+
}
100+
101+
// Leave only the dynamic mesh axes to be queried.
102+
ClusterShapeOp newShapeOp =
103+
builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104+
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105+
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106+
}
107+
108+
rewriter.replaceAllUsesWith(op.getResults(), newResults);
109+
110+
return success();
111+
}
112+
113+
private:
114+
SymbolTableCollection &symbolTable;
115+
};
116+
117+
} // namespace
118+
119+
void populateFoldingPatterns(RewritePatternSet &patterns,
120+
SymbolTableCollection &symbolTable) {
121+
patterns.add<ClusterShapeFolder>(symbolTable, patterns.getContext());
122+
}
123+
38124
} // namespace mesh
39125
} // namespace mlir

mlir/test/Dialect/Mesh/folding.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-opt -test-mesh-folding %s | FileCheck %s
2+
3+
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
4+
mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
5+
6+
// CHECK-LABEL: func.func @cluster_shape_op_folding
7+
func.func @cluster_shape_op_folding() -> (index, index) {
8+
// CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
9+
// CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
10+
%0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
11+
// CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
12+
return %0#0, %0#1 : index, index
13+
}
14+
15+
// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
16+
func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
17+
// CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
18+
// CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
19+
%0:2 = mesh.cluster_shape @mesh1 : index, index
20+
// CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
21+
return %0#0, %0#1 : index, index
22+
}

mlir/test/lib/Dialect/Mesh/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Exclude tests from libMLIR.so
2-
add_mlir_library(MLIRMeshTestSimplifications
2+
add_mlir_library(MLIRMeshTest
3+
TestFolding.cpp
34
TestReshardingSpmdization.cpp
45
TestSimplifications.cpp
56

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- TestSimplification.cpp - Test simplification -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
11+
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
12+
#include "mlir/IR/Diagnostics.h"
13+
#include "mlir/IR/SymbolTable.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Support/LogicalResult.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
#include <memory>
18+
19+
using namespace mlir;
20+
21+
namespace {
22+
23+
struct TestMeshFoldingPass
24+
: public PassWrapper<TestMeshFoldingPass, OperationPass<>> {
25+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshFoldingPass)
26+
27+
void runOnOperation() override;
28+
void getDependentDialects(DialectRegistry &registry) const override {
29+
registry.insert<mesh::MeshDialect>();
30+
}
31+
StringRef getArgument() const final { return "test-mesh-folding"; }
32+
StringRef getDescription() const final { return "Test mesh folding."; }
33+
};
34+
} // namespace
35+
36+
void TestMeshFoldingPass::runOnOperation() {
37+
RewritePatternSet patterns(&getContext());
38+
SymbolTableCollection symbolTables;
39+
mesh::populateFoldingPatterns(patterns, symbolTables);
40+
if (failed(
41+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
42+
getOperation()->emitError()
43+
<< "Rewrite patter application did not converge.";
44+
return signalPassFailure();
45+
}
46+
}
47+
48+
namespace mlir {
49+
namespace test {
50+
void registerTestMeshFoldingPass() { PassRegistration<TestMeshFoldingPass>(); }
51+
} // namespace test
52+
} // namespace mlir

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ if(MLIR_INCLUDE_TESTS)
2626
MLIRLoopLikeInterfaceTestPasses
2727
MLIRMathTestPasses
2828
MLIRMemRefTestPasses
29-
MLIRMeshTestSimplifications
29+
MLIRMeshTest
3030
MLIRNVGPUTestPasses
3131
MLIRSCFTestPasses
3232
MLIRShapeTestPasses

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
118118
void registerTestMathPolynomialApproximationPass();
119119
void registerTestMemRefDependenceCheck();
120120
void registerTestMemRefStrideCalculation();
121+
void registerTestMeshFoldingPass();
121122
void registerTestMeshSimplificationsPass();
122123
void registerTestMeshReshardingSpmdizationPass();
123124
void registerTestNextAccessPass();
@@ -240,6 +241,7 @@ void registerTestPasses() {
240241
mlir::test::registerTestMathPolynomialApproximationPass();
241242
mlir::test::registerTestMemRefDependenceCheck();
242243
mlir::test::registerTestMemRefStrideCalculation();
244+
mlir::test::registerTestMeshFoldingPass();
243245
mlir::test::registerTestMeshSimplificationsPass();
244246
mlir::test::registerTestMeshReshardingSpmdizationPass();
245247
mlir::test::registerTestNextAccessPass();

0 commit comments

Comments
 (0)