Skip to content

Commit 363f6e5

Browse files
committed
[mlir][mesh] Add endomorphism simplification for all-reduce
Does transformations like all_reduce(x) + all_reduce(y) -> all_reduce(x + y) max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y)) when the all_reduce element-wise op is max.
1 parent f5e50b2 commit 363f6e5

File tree

11 files changed

+480
-0
lines changed

11 files changed

+480
-0
lines changed

mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ namespace func {
1717
class FuncOp;
1818
}
1919

20+
class RewritePatternSet;
21+
2022
namespace mesh {
2123

2224
//===----------------------------------------------------------------------===//
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
10+
#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
11+
12+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Transforms/EndomorphismSimplification.h"
15+
#include "llvm/Support/Casting.h"
16+
#include <algorithm>
17+
#include <iterator>
18+
#include <memory>
19+
#include <utility>
20+
21+
namespace mlir {
22+
namespace mesh {
23+
24+
template <typename AlgebraicOp>
25+
void populateAllReduceEndomorphismSimplificationPatterns(
26+
RewritePatternSet &patterns, Partial reduction) {
27+
auto getEndomorphismOpOperand = [](Operation *op) {
28+
auto allReduceOp = llvm::cast<AllReduceOp>(op);
29+
return &allReduceOp.getInputMutable();
30+
};
31+
auto getEndomorphismOpResult = [](Operation *op) {
32+
auto allReduceOp = llvm::cast<AllReduceOp>(op);
33+
return allReduceOp->getResult(0);
34+
};
35+
auto getAlgebraicOpOperands = [](Operation *op,
36+
SmallVector<OpOperand *> &operands) {
37+
auto algebraicOp = llvm::cast<AlgebraicOp>(op);
38+
std::transform(algebraicOp->getOpOperands().begin(),
39+
algebraicOp->getOpOperands().end(),
40+
std::back_inserter(operands),
41+
[](OpOperand &operand) { return &operand; });
42+
};
43+
auto getAlgebraicOpResult = [](Operation *op) {
44+
auto algebraicOp = llvm::cast<AlgebraicOp>(op);
45+
return algebraicOp->getResult(0);
46+
};
47+
auto isEndomorphismOp = [reduction](Operation *op,
48+
std::optional<Operation *> referenceOp) {
49+
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
50+
if (!allReduceOp ||
51+
allReduceOp.getInput().getType().getElementType() !=
52+
allReduceOp.getResult().getType().getElementType() ||
53+
allReduceOp.getReduction() != reduction) {
54+
return false;
55+
}
56+
57+
if (!referenceOp) {
58+
return true;
59+
}
60+
61+
auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
62+
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
63+
allReduceOp.getInput().getType().getElementType() ==
64+
refAllReduceOp.getInput().getType().getElementType();
65+
};
66+
auto isAlgebraicOp = [](Operation *op) {
67+
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
68+
};
69+
70+
using ConcreteEndomorphismSimplification = EndomorphismSimplification<
71+
std::decay_t<decltype(getEndomorphismOpOperand)>,
72+
std::decay_t<decltype(getEndomorphismOpResult)>,
73+
std::decay_t<decltype(getAlgebraicOpOperands)>,
74+
std::decay_t<decltype(getAlgebraicOpResult)>,
75+
std::decay_t<decltype(isEndomorphismOp)>,
76+
std::decay_t<decltype(isAlgebraicOp)>>;
77+
patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
78+
std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
79+
std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
80+
std::move(isEndomorphismOp), std::move(isAlgebraicOp),
81+
AlgebraicOp::getOperationName(), 1, patterns.getContext()));
82+
}
83+
84+
void populateSimplificationPatterns(RewritePatternSet &patterns);
85+
86+
} // namespace mesh
87+
} // namespace mlir
88+
89+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
10+
#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
11+
12+
#include "mlir/IR/IRMapping.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/Value.h"
15+
#include "mlir/Support/LLVM.h"
16+
#include "llvm/ADT/SmallVector.h"
17+
#include "llvm/Support/Casting.h"
18+
#include <iterator>
19+
#include <optional>
20+
#include <type_traits>
21+
#include <utility>
22+
23+
#include "mlir/Support/LogicalResult.h"
24+
25+
namespace mlir {
26+
27+
// If `f` is an endomorphism with respect to the algebraic structure induced by
28+
// function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
29+
// `f(g(x1, x2, ..., xn))`.
30+
// `g` is the algebraic operation and `f` is the endomorphism.
31+
//
32+
// Functors:
33+
// ---------
34+
// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
35+
// Returns the operand relevant to the endomorphism.
36+
// There may be other operands that are not relevant.
37+
//
38+
// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
39+
// Returns the result relevant to the endomorphism.
40+
//
41+
// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
42+
// Populates into the vector the operands relevant to the endomorphism.
43+
//
44+
// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
45+
// Return the result relevant to the endomorphism.
46+
//
47+
// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
48+
// Check if the operation is an endomorphism of the required type.
49+
// Additionally if the optional is present checks if the operations are
50+
// compatible endomorphisms.
51+
//
52+
// `IsAlgebraicOpFn`: `(Operation*) -> bool`
53+
// Check if the operation is an operation of the algebraic structure.
54+
template <typename GetEndomorphismOpOperandFn,
55+
typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
56+
typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
57+
typename IsAlgebraicOpFn>
58+
struct EndomorphismSimplification : RewritePattern {
59+
template <typename GetEndomorphismOpOperandFnArg,
60+
typename GetEndomorphismOpResultFnArg,
61+
typename GetAlgebraicOpOperandsFnArg,
62+
typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
63+
typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
64+
EndomorphismSimplification(
65+
GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
66+
GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
67+
GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
68+
GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
69+
IsEndomorphismOpFnArg &&isEndomorphismOp,
70+
IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
71+
: RewritePattern(std::forward<RewritePatternArgs>(args)...),
72+
getEndomorphismOpOperand(std::forward<GetEndomorphismOpOperandFnArg>(
73+
getEndomorphismOpOperand)),
74+
getEndomorphismOpResult(std::forward<GetEndomorphismOpResultFnArg>(
75+
getEndomorphismOpResult)),
76+
getAlgebraicOpOperands(
77+
std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands)),
78+
getAlgebraicOpResult(
79+
std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult)),
80+
isEndomorphismOp(std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp)),
81+
isAlgebraicOp(std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp)) {}
82+
83+
LogicalResult matchAndRewrite(Operation *op,
84+
PatternRewriter &rewriter) const override {
85+
if (failed(matchOp(op, algebraicOpOperands))) {
86+
return failure();
87+
}
88+
return rewriteOp(op, algebraicOpOperands, rewriter);
89+
}
90+
91+
private:
92+
LogicalResult matchOp(Operation *algebraicOp,
93+
SmallVector<OpOperand *> &algebraicOpOperands) const {
94+
if (!isAlgebraicOp(algebraicOp)) {
95+
return failure();
96+
}
97+
algebraicOpOperands.clear();
98+
getAlgebraicOpOperands(algebraicOp, algebraicOpOperands);
99+
if (algebraicOpOperands.empty()) {
100+
return failure();
101+
}
102+
103+
Operation *firstEndomorphismOp =
104+
algebraicOpOperands.front()->get().getDefiningOp();
105+
if (!firstEndomorphismOp ||
106+
!isEndomorphismOp(firstEndomorphismOp, std::nullopt)) {
107+
return failure();
108+
}
109+
OpResult firstEndomorphismOpResult =
110+
getEndomorphismOpResult(firstEndomorphismOp);
111+
if (firstEndomorphismOpResult != algebraicOpOperands.front()->get()) {
112+
return failure();
113+
}
114+
115+
for (auto operand : algebraicOpOperands) {
116+
Operation *endomorphismOp = operand->get().getDefiningOp();
117+
if (!endomorphismOp ||
118+
!isEndomorphismOp(endomorphismOp, firstEndomorphismOp)) {
119+
return failure();
120+
}
121+
}
122+
return success();
123+
}
124+
125+
LogicalResult rewriteOp(Operation *algebraicOp,
126+
const SmallVector<OpOperand *> &algebraicOpOperands,
127+
PatternRewriter &rewriter) const {
128+
irMapping.clear();
129+
for (auto operand : algebraicOpOperands) {
130+
Operation *endomorphismOp = operand->get().getDefiningOp();
131+
irMapping.map(operand->get(),
132+
getEndomorphismOpOperand(endomorphismOp)->get());
133+
}
134+
Operation *newAlgebraicOp = rewriter.clone(*algebraicOp, irMapping);
135+
136+
irMapping.clear();
137+
assert(!algebraicOpOperands.empty());
138+
Operation *firstEndomorphismOp =
139+
algebraicOpOperands[0]->get().getDefiningOp();
140+
irMapping.map(getEndomorphismOpOperand(firstEndomorphismOp)->get(),
141+
getAlgebraicOpResult(newAlgebraicOp));
142+
Operation *newEndomorphismOp =
143+
rewriter.clone(*firstEndomorphismOp, irMapping);
144+
rewriter.replaceAllUsesWith(getAlgebraicOpResult(algebraicOp),
145+
getEndomorphismOpResult(newEndomorphismOp));
146+
return success();
147+
}
148+
149+
GetEndomorphismOpOperandFn getEndomorphismOpOperand;
150+
GetEndomorphismOpResultFn getEndomorphismOpResult;
151+
GetAlgebraicOpOperandsFn getAlgebraicOpOperands;
152+
GetAlgebraicOpResultFn getAlgebraicOpResult;
153+
IsEndomorphismOpFn isEndomorphismOp;
154+
IsAlgebraicOpFn isAlgebraicOp;
155+
mutable SmallVector<OpOperand *> algebraicOpOperands;
156+
mutable IRMapping irMapping;
157+
};
158+
159+
} // namespace mlir
160+
161+
#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_

mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRMeshTransforms
2+
Simplifications.cpp
23
ShardingPropagation.cpp
34

45
ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
910
MLIRShardingInterface
1011

1112
LINK_LIBS PUBLIC
13+
MLIRArithDialect
1214
MLIRFuncDialect
1315
MLIRIR
1416
MLIRMeshDialect
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===- Patterns.cpp - Mesh Patterns -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
11+
12+
namespace mlir {
13+
namespace mesh {
14+
15+
void populateSimplificationPatterns(RewritePatternSet &patterns) {
16+
populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
17+
patterns, Partial::Sum);
18+
populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
19+
patterns, Partial::Sum);
20+
21+
populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
22+
patterns, Partial::Min);
23+
populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
24+
patterns, Partial::Min);
25+
populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
26+
patterns, Partial::Min);
27+
28+
populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
29+
patterns, Partial::Max);
30+
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
31+
patterns, Partial::Max);
32+
populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
33+
patterns, Partial::Max);
34+
}
35+
36+
} // namespace mesh
37+
} // namespace mlir

0 commit comments

Comments
 (0)