Skip to content

Commit 4b34467

Browse files
authored
[mlir][mesh] Add endomorphism simplification for all-reduce (#73150)
Does transformations like all_reduce(x) + all_reduce(y) -> all_reduce(x + y) max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y)) when the all_reduce element-wise op is max. Added general rewrite pattern HomomorphismSimplification and EndomorphismSimplification that encapsulate the general algorithm. Made specialization for all-reduce with respect to addf, addi, minsi, maxsi, minimumf and maximumf in the Arithmetic dialect.
1 parent 8063622 commit 4b34467

File tree

11 files changed

+659
-0
lines changed

11 files changed

+659
-0
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
10+
#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
11+
12+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Transforms/EndomorphismSimplification.h"
15+
#include "llvm/Support/Casting.h"
16+
#include <algorithm>
17+
#include <iterator>
18+
#include <memory>
19+
#include <utility>
20+
21+
namespace mlir {
22+
namespace mesh {
23+
24+
// If we have an algebraic op like "+" and a summing all-reduce,
25+
// `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
26+
// `all_reduce_sum(x + y)`.
27+
//
28+
// Another example with `min`.
29+
// `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
30+
// `all_reduce_min(min(x, y))`.
31+
//
32+
// Works only with algebraic ops that have all their operands relevant
33+
// to the all-reduce endomorphism.
34+
// Will not work with some op `f(x, y, z)` where only `x` and `y` form
35+
// the algebraic structure.
36+
template <typename AlgebraicOp>
37+
void populateAllReduceEndomorphismSimplificationPatterns(
38+
RewritePatternSet &patterns, Partial reduction) {
39+
auto getEndomorphismOpOperand = [](Operation *op) {
40+
auto allReduceOp = llvm::cast<AllReduceOp>(op);
41+
return &allReduceOp.getInputMutable();
42+
};
43+
auto getEndomorphismOpResult = [](Operation *op) {
44+
auto allReduceOp = llvm::cast<AllReduceOp>(op);
45+
return allReduceOp->getResult(0);
46+
};
47+
auto getAlgebraicOpOperands = [](Operation *op,
48+
SmallVector<OpOperand *> &operands) {
49+
auto algebraicOp = llvm::cast<AlgebraicOp>(op);
50+
std::transform(algebraicOp->getOpOperands().begin(),
51+
algebraicOp->getOpOperands().end(),
52+
std::back_inserter(operands),
53+
[](OpOperand &operand) { return &operand; });
54+
};
55+
auto getAlgebraicOpResult = [](Operation *op) {
56+
auto algebraicOp = llvm::cast<AlgebraicOp>(op);
57+
return algebraicOp->getResult(0);
58+
};
59+
auto isEndomorphismOp = [reduction](Operation *op,
60+
std::optional<Operation *> referenceOp) {
61+
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
62+
if (!allReduceOp ||
63+
allReduceOp.getInput().getType().getElementType() !=
64+
allReduceOp.getResult().getType().getElementType() ||
65+
allReduceOp.getReduction() != reduction) {
66+
return false;
67+
}
68+
69+
// Dont't use simplify if the all-reduce is used other than by the
70+
// algebraic op.
71+
// TODO: maybe handle this by an additional pass that later reverses the
72+
// simplification if there are other uses left other optimizations have
73+
// been done.
74+
if (!allReduceOp->hasOneUse()) {
75+
return false;
76+
}
77+
78+
if (!referenceOp) {
79+
return true;
80+
}
81+
82+
auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
83+
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
84+
allReduceOp.getInput().getType().getElementType() ==
85+
refAllReduceOp.getInput().getType().getElementType();
86+
};
87+
auto isAlgebraicOp = [](Operation *op) {
88+
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
89+
};
90+
91+
using ConcreteEndomorphismSimplification = EndomorphismSimplification<
92+
std::decay_t<decltype(getEndomorphismOpOperand)>,
93+
std::decay_t<decltype(getEndomorphismOpResult)>,
94+
std::decay_t<decltype(getAlgebraicOpOperands)>,
95+
std::decay_t<decltype(getAlgebraicOpResult)>,
96+
std::decay_t<decltype(isEndomorphismOp)>,
97+
std::decay_t<decltype(isAlgebraicOp)>>;
98+
patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
99+
std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
100+
std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
101+
std::move(isEndomorphismOp), std::move(isAlgebraicOp),
102+
AlgebraicOp::getOperationName(), 1, patterns.getContext()));
103+
}
104+
105+
void populateSimplificationPatterns(RewritePatternSet &patterns);
106+
107+
} // namespace mesh
108+
} // namespace mlir
109+
110+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
10+
#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
11+
12+
#include "mlir/Transforms/HomomorphismSimplification.h"
13+
14+
namespace mlir {
15+
16+
namespace detail {
17+
struct CreateAlgebraicOpForEndomorphismSimplification {
18+
Operation *operator()(Operation *op, IRMapping &operandsRemapping,
19+
PatternRewriter &rewriter) const {
20+
return rewriter.clone(*op, operandsRemapping);
21+
}
22+
};
23+
} // namespace detail
24+
25+
// If `f` is an endomorphism with respect to the algebraic structure induced by
26+
// function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
27+
// `f(g(x1, x2, ..., xn))`.
28+
// `g` is the algebraic operation and `f` is the endomorphism.
29+
//
30+
// Functors:
31+
// ---------
32+
// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
33+
// Returns the operand relevant to the endomorphism.
34+
// There may be other operands that are not relevant.
35+
//
36+
// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
37+
// Returns the result relevant to the endomorphism.
38+
//
39+
// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
40+
// Populates into the vector the operands relevant to the endomorphism.
41+
//
42+
// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
43+
// Return the result relevant to the endomorphism.
44+
//
45+
// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
46+
// Check if the operation is an endomorphism of the required type.
47+
// Additionally if the optional is present checks if the operations are
48+
// compatible endomorphisms.
49+
//
50+
// `IsAlgebraicOpFn`: `(Operation*) -> bool`
51+
// Check if the operation is an operation of the algebraic structure.
52+
template <typename GetEndomorphismOpOperandFn,
53+
typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
54+
typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
55+
typename IsAlgebraicOpFn>
56+
struct EndomorphismSimplification
57+
: HomomorphismSimplification<
58+
GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
59+
GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
60+
GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
61+
detail::CreateAlgebraicOpForEndomorphismSimplification> {
62+
template <typename GetEndomorphismOpOperandFnArg,
63+
typename GetEndomorphismOpResultFnArg,
64+
typename GetAlgebraicOpOperandsFnArg,
65+
typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
66+
typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
67+
EndomorphismSimplification(
68+
GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
69+
GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
70+
GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
71+
GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
72+
IsEndomorphismOpFnArg &&isEndomorphismOp,
73+
IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
74+
: HomomorphismSimplification<
75+
GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
76+
GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
77+
GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
78+
detail::CreateAlgebraicOpForEndomorphismSimplification>(
79+
std::forward<GetEndomorphismOpOperandFnArg>(
80+
getEndomorphismOpOperand),
81+
std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult),
82+
std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands),
83+
std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
84+
std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
85+
std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp),
86+
std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp),
87+
detail::CreateAlgebraicOpForEndomorphismSimplification(),
88+
std::forward<RewritePatternArgs>(args)...) {}
89+
};
90+
91+
} // namespace mlir
92+
93+
#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
10+
#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
11+
12+
#include "mlir/IR/IRMapping.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/Value.h"
15+
#include "mlir/Support/LLVM.h"
16+
#include "mlir/Support/LogicalResult.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/Support/Casting.h"
19+
#include <iterator>
20+
#include <optional>
21+
#include <type_traits>
22+
#include <utility>
23+
24+
namespace mlir {
25+
26+
// If `h` is an homomorphism with respect to the source algebraic structure
27+
// induced by function `s` and the target algebraic structure induced by
28+
// function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into
29+
// `h(t(x1, x2, ..., xn))`.
30+
//
31+
// Functors:
32+
// ---------
33+
// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
34+
// Returns the operand relevant to the homomorphism.
35+
// There may be other operands that are not relevant.
36+
//
37+
// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
38+
// Returns the result relevant to the homomorphism.
39+
//
40+
// `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) ->
41+
// void` Populates into the vector the operands relevant to the homomorphism.
42+
//
43+
// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
44+
// Return the result of the source algebraic operation relevant to the
45+
// homomorphism.
46+
//
47+
// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
48+
// Return the result of the target algebraic operation relevant to the
49+
// homomorphism.
50+
//
51+
// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
52+
// Check if the operation is an homomorphism of the required type.
53+
// Additionally if the optional is present checks if the operations are
54+
// compatible homomorphisms.
55+
//
56+
// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
57+
// Check if the operation is an operation of the algebraic structure.
58+
//
59+
// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
60+
// PatternRewriter &rewriter) -> Operation*`
61+
template <typename GetHomomorphismOpOperandFn,
62+
typename GetHomomorphismOpResultFn,
63+
typename GetSourceAlgebraicOpOperandsFn,
64+
typename GetSourceAlgebraicOpResultFn,
65+
typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
66+
typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
67+
struct HomomorphismSimplification : public RewritePattern {
68+
template <typename GetHomomorphismOpOperandFnArg,
69+
typename GetHomomorphismOpResultFnArg,
70+
typename GetSourceAlgebraicOpOperandsFnArg,
71+
typename GetSourceAlgebraicOpResultFnArg,
72+
typename GetTargetAlgebraicOpResultFnArg,
73+
typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
74+
typename CreateTargetAlgebraicOpFnArg,
75+
typename... RewritePatternArgs>
76+
HomomorphismSimplification(
77+
GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
78+
GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
79+
GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
80+
GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
81+
GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
82+
IsHomomorphismOpFnArg &&isHomomorphismOp,
83+
IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
84+
CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
85+
RewritePatternArgs &&...args)
86+
: RewritePattern(std::forward<RewritePatternArgs>(args)...),
87+
getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
88+
getHomomorphismOpOperand)),
89+
getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
90+
getHomomorphismOpResult)),
91+
getSourceAlgebraicOpOperands(
92+
std::forward<GetSourceAlgebraicOpOperandsFnArg>(
93+
getSourceAlgebraicOpOperands)),
94+
getSourceAlgebraicOpResult(
95+
std::forward<GetSourceAlgebraicOpResultFnArg>(
96+
getSourceAlgebraicOpResult)),
97+
getTargetAlgebraicOpResult(
98+
std::forward<GetTargetAlgebraicOpResultFnArg>(
99+
getTargetAlgebraicOpResult)),
100+
isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
101+
isSourceAlgebraicOp(
102+
std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
103+
createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
104+
createTargetAlgebraicOpFn)) {}
105+
106+
LogicalResult matchAndRewrite(Operation *op,
107+
PatternRewriter &rewriter) const override {
108+
SmallVector<OpOperand *> algebraicOpOperands;
109+
if (failed(matchOp(op, algebraicOpOperands))) {
110+
return failure();
111+
}
112+
return rewriteOp(op, algebraicOpOperands, rewriter);
113+
}
114+
115+
private:
116+
LogicalResult
117+
matchOp(Operation *sourceAlgebraicOp,
118+
SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
119+
if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
120+
return failure();
121+
}
122+
sourceAlgebraicOpOperands.clear();
123+
getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
124+
if (sourceAlgebraicOpOperands.empty()) {
125+
return failure();
126+
}
127+
128+
Operation *firstHomomorphismOp =
129+
sourceAlgebraicOpOperands.front()->get().getDefiningOp();
130+
if (!firstHomomorphismOp ||
131+
!isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
132+
return failure();
133+
}
134+
OpResult firstHomomorphismOpResult =
135+
getHomomorphismOpResult(firstHomomorphismOp);
136+
if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
137+
return failure();
138+
}
139+
140+
for (auto operand : sourceAlgebraicOpOperands) {
141+
Operation *homomorphismOp = operand->get().getDefiningOp();
142+
if (!homomorphismOp ||
143+
!isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
144+
return failure();
145+
}
146+
}
147+
return success();
148+
}
149+
150+
LogicalResult
151+
rewriteOp(Operation *sourceAlgebraicOp,
152+
const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
153+
PatternRewriter &rewriter) const {
154+
IRMapping irMapping;
155+
for (auto operand : sourceAlgebraicOpOperands) {
156+
Operation *homomorphismOp = operand->get().getDefiningOp();
157+
irMapping.map(operand->get(),
158+
getHomomorphismOpOperand(homomorphismOp)->get());
159+
}
160+
Operation *targetAlgebraicOp =
161+
createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
162+
163+
irMapping.clear();
164+
assert(!sourceAlgebraicOpOperands.empty());
165+
Operation *firstHomomorphismOp =
166+
sourceAlgebraicOpOperands[0]->get().getDefiningOp();
167+
irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
168+
getTargetAlgebraicOpResult(targetAlgebraicOp));
169+
Operation *newHomomorphismOp =
170+
rewriter.clone(*firstHomomorphismOp, irMapping);
171+
rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
172+
getHomomorphismOpResult(newHomomorphismOp));
173+
return success();
174+
}
175+
176+
GetHomomorphismOpOperandFn getHomomorphismOpOperand;
177+
GetHomomorphismOpResultFn getHomomorphismOpResult;
178+
GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
179+
GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
180+
GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
181+
IsHomomorphismOpFn isHomomorphismOp;
182+
IsSourceAlgebraicOpFn isSourceAlgebraicOp;
183+
CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
184+
};
185+
186+
} // namespace mlir
187+
188+
#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_

0 commit comments

Comments
 (0)