Skip to content

Commit ff2720d

Browse files
authored
[mlir][mesh] Deduplicate iterator type and partial type information (#81920)
The two types duplicated mostly the same values. Here they are decomposed to carry orthogonal and complementary information. Use `utils::IteratorType` instead of `mesh::IteratorType`. The iterator type now has only parallel and reduction values. Rename `Partial` to `ReductionKind`. Add `getReductionLoopIteratorKinds` method to `ShardingInterface`.
1 parent d228191 commit ff2720d

File tree

12 files changed

+93
-108
lines changed

12 files changed

+93
-108
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16"
4141
// Mesh Enums.
4242
//===----------------------------------------------------------------------===//
4343

44-
def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
44+
def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
45+
"Reduction of an iterator/mesh dimension.", [
4546
I32EnumAttrCase<"Sum", 1, "sum">,
4647
I32EnumAttrCase<"Max", 2, "max">,
4748
I32EnumAttrCase<"Min", 3, "min">,
@@ -51,26 +52,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
5152
let cppNamespace = "::mlir::mesh";
5253
}
5354

54-
def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
55+
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
5556
let assemblyFormat = "`<` $value `>`";
5657
}
5758

58-
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
59-
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
60-
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
61-
// is partial.
62-
def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
63-
I32EnumAttrCase<"Parallel", 1, "parallel">,
64-
I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
65-
I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
66-
I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
67-
I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
68-
I32EnumAttrCase<"Invalid", 100, "invalid">
69-
]> {
70-
let genSpecializedAttr = 0;
71-
let cppNamespace = "::mlir::mesh";
72-
}
73-
7459
//===----------------------------------------------------------------------===//
7560
// Mesh Attribute
7661
//===----------------------------------------------------------------------===//
@@ -83,14 +68,15 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
8368
"The mesh on which tensors are sharded.">:$mesh,
8469
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
8570
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
86-
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
71+
OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
8772
);
8873

8974
let summary = "Attribute that extends tensor type to distributed tensor type.";
9075

9176
let description = [{
92-
The MeshSharding attribute could be used in the encoding of a
93-
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
77+
The MeshSharding attribute is used in a `mesh.shard` operation.
78+
It specifies how a tensor is sharded and distributed across the process
79+
mesh.
9480

9581
1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
9682
mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -107,13 +93,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
10793

10894
4. `partial_type`: indicates the reduction type of the possible all-reduce
10995
op. It has 4 possible values:
110-
- `partial_sum`: denotes it's an all-reduce-sum
111-
- `partial_max`: denotes it's an all-reduce-max
112-
- `partial_min`: denotes it's an all-reduce-min
113-
- `partial_generic`: denotes that the all-reduce type is complex and cannot
114-
be represented merely by a simple sum, max, or min. The exact reduction
115-
computation may be derived from the semantics of the corresponding operation
116-
or from the reduction computation IR
96+
`generic` is not an allowed value inside a shard attribute.
11797

11898
Example:
11999

@@ -149,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
149129
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
150130
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
151131
"ArrayRef<MeshAxis>": $partial_axes,
152-
"mesh::Partial": $partial_type), [{
132+
"mesh::ReductionKind": $partial_type), [{
153133
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
154134
split_axes, [&](ArrayRef<MeshAxis> array) {
155135
return MeshAxesAttr::get($_ctxt, array);
@@ -159,7 +139,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
159139
}]>,
160140
AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
161141
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
162-
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
142+
return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
163143
}]>
164144
];
165145

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1314
#include "mlir/IR/BuiltinTypeInterfaces.h"
1415
#include "mlir/IR/OpDefinition.h"
1516
#include "mlir/IR/PatternMatch.h"
@@ -38,18 +39,16 @@ using MeshAxesAttr = DenseI16ArrayAttr;
3839
namespace mlir {
3940
namespace mesh {
4041

41-
bool isReductionLoop(IteratorType iType);
42-
43-
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
42+
inline bool isReductionLoop(utils::IteratorType iType) {
43+
return iType == utils::IteratorType::reduction;
44+
}
4445

4546
template <typename T>
4647
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
4748
while (!array.empty() && array.back().empty())
4849
array.pop_back();
4950
}
5051

51-
Partial getPartialTypeFromReduction(IteratorType iType);
52-
5352
// Is the same tensor replicated on all processes.
5453
inline bool isFullReplication(MeshShardingAttr attr) {
5554
return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
330330
}];
331331
let arguments = !con(commonArgs, (ins
332332
AnyRankedTensor:$input,
333-
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
333+
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
334334
));
335335
let results = (outs
336336
AnyRankedTensor:$result
@@ -629,7 +629,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
629629
}];
630630
let arguments = !con(commonArgs, (ins
631631
AnyRankedTensor:$input,
632-
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
632+
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
633633
DenseI64ArrayAttr:$root,
634634
Variadic<Index>:$root_dynamic
635635
));
@@ -692,7 +692,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
692692
}];
693693
let arguments = !con(commonArgs, (ins
694694
AnyNon0RankedTensor:$input,
695-
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
695+
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
696696
IndexAttr:$scatter_axis
697697
));
698698
let results = (outs

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
1111

1212
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1314
#include "mlir/IR/Value.h"
1415
#include "mlir/Support/LLVM.h"
1516

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,39 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
2626
output tensors.
2727

2828
Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
29-
types are parallel, parallel, reduction-sum. This indicates that M and
29+
types are parallel, parallel, reduction. This indicates that M and
3030
N are traversed in parallel, while the K dimension is used for
3131
reduction.
32-
33-
Example 2: A softmax op's loop iterator types are parallel and
34-
invalid. The second dimension is considered as invalid because it is
35-
neither parallel nor any kind of reduction.
3632
}],
37-
/*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
33+
/*retType=*/"SmallVector<mlir::utils::IteratorType>",
3834
/*methodName=*/"getLoopIteratorTypes",
3935
/*args=*/(ins),
4036
/*methodBody=*/"",
4137
/*defaultImplementation=*/"return {};"
4238
>,
39+
InterfaceMethod<
40+
/*desc=*/[{
41+
Return the kind of all reduction loop iterators.
42+
The order is the same as the result from
43+
`getLoopIteratorTypes`.
44+
45+
Example 1:
46+
iterator types = (parallel, reduction, parallel, reduction)
47+
|| ||
48+
reduction kinds = ( sum, max)
49+
50+
Example 2:
51+
A softmax op's loop iterator types are parallel and
52+
reduction.
53+
The reduction iterator will be of kind `generic`, since it is none of
54+
the available presets.
55+
}],
56+
/*retType=*/"SmallVector<ReductionKind>",
57+
/*methodName=*/"getReductionLoopIteratorKinds",
58+
/*args=*/(ins),
59+
/*methodBody=*/"",
60+
/*defaultImplementation=*/"return {};"
61+
>,
4362
InterfaceMethod<
4463
/*desc=*/[{
4564
Return the indexing maps attribute within the current operation.

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ template <typename Op>
3636
struct IndependentParallelIteratorDomainShardingInterface
3737
: public ShardingInterface::ExternalModel<
3838
IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
39-
SmallVector<IteratorType> getLoopIteratorTypes(Operation *operation) const {
40-
SmallVector<IteratorType> iterTypes;
39+
SmallVector<utils::IteratorType>
40+
getLoopIteratorTypes(Operation *operation) const {
41+
SmallVector<utils::IteratorType> iterTypes;
4142
for (Type t : operation->getOperandTypes()) {
4243
populateIteratorTypes(t, iterTypes);
4344
}
@@ -65,16 +66,17 @@ struct IndependentParallelIteratorDomainShardingInterface
6566
}
6667

6768
private:
68-
void populateIteratorTypes(Type t,
69-
SmallVector<IteratorType> &iterTypes) const {
69+
void
70+
populateIteratorTypes(Type t,
71+
SmallVector<utils::IteratorType> &iterTypes) const {
7072
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
7173
if (!rankedTensorType) {
7274
return;
7375
}
7476

7577
iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
7678
for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
77-
iterTypes.push_back(IteratorType::Parallel);
79+
iterTypes.push_back(utils::IteratorType::parallel);
7880
}
7981
}
8082
};
@@ -84,12 +86,13 @@ template <typename ElemwiseOp>
8486
struct ElementwiseShardingInterface
8587
: public ShardingInterface::ExternalModel<
8688
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
87-
SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
89+
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
8890
Value val = op->getOperand(0);
8991
auto type = val.getType().dyn_cast<RankedTensorType>();
9092
if (!type)
9193
return {};
92-
SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
94+
SmallVector<utils::IteratorType> types(type.getRank(),
95+
utils::IteratorType::parallel);
9396
return types;
9497
}
9598

mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ namespace mesh {
3838
// the algebraic structure.
3939
template <typename AlgebraicOp>
4040
void populateAllReduceEndomorphismSimplificationPatterns(
41-
RewritePatternSet &patterns, Partial reduction) {
41+
RewritePatternSet &patterns, ReductionKind reduction) {
4242
auto getEndomorphismOpOperand = [](Operation *op) {
4343
auto allReduceOp = llvm::cast<AllReduceOp>(op);
4444
return &allReduceOp.getInputMutable();

mlir/lib/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRMeshDialect
99

1010
LINK_LIBS PUBLIC
1111
MLIRArithDialect
12+
MLIRDialectUtils
1213
MLIRIR
1314
MLIRSupport
1415
MLIRViewLikeInterface

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -148,33 +148,6 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
148148
return success();
149149
}
150150

151-
bool mesh::isReductionLoop(IteratorType iType) {
152-
return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
153-
}
154-
155-
bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
156-
return (partial == Partial::Generic &&
157-
iType == IteratorType::ReductionGeneric) ||
158-
(partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
159-
(partial == Partial::Max && iType == IteratorType::ReductionMax) ||
160-
(partial == Partial::Min && iType == IteratorType::ReductionMin);
161-
}
162-
163-
Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
164-
switch (iType) {
165-
case IteratorType::ReductionGeneric:
166-
return Partial::Generic;
167-
case IteratorType::ReductionSum:
168-
return Partial::Sum;
169-
case IteratorType::ReductionMax:
170-
return Partial::Max;
171-
case IteratorType::ReductionMin:
172-
return Partial::Min;
173-
default:
174-
llvm_unreachable("No corresponding partial type can be found");
175-
}
176-
}
177-
178151
template <typename InShape, typename MeshShape, typename SplitAxes,
179152
typename OutShape>
180153
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
@@ -278,7 +251,7 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
278251
LogicalResult
279252
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
280253
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
281-
ArrayRef<MeshAxis> partialAxes, Partial) {
254+
ArrayRef<MeshAxis> partialAxes, ReductionKind) {
282255
// TODO: At present mesh symbol ref is not verified. This is due to the
283256
// difficulty in fetching the corresponding symbol op based on an attribute.
284257

0 commit comments

Comments
 (0)