Skip to content

Commit b0d5b4d

Browse files
authored
[MLIR][Mesh] Add sharding propagation pass (#71261)
Add a pass that propagates sharding information throughout the graph. After this pass, each of the operations' operands and results is annotated with a mesh.shard operation. The pass is driven by a newly added ShardingInterface, and an implementation for element-wise and matmul ops in the TOSA dialect is provided.
1 parent 87f5e22 commit b0d5b4d

23 files changed

+1470
-5
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
add_subdirectory(Interfaces)
12
add_subdirectory(IR)
3+
add_subdirectory(Transforms)

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
4949
let cppNamespace = "::mlir::mesh";
5050
}
5151

52+
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
53+
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
54+
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
55+
// is partial.
56+
def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
57+
I32EnumAttrCase<"Parallel", 1, "parallel">,
58+
I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
59+
I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
60+
I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
61+
I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
62+
I32EnumAttrCase<"Invalid", 100, "invalid">
63+
]> {
64+
let genSpecializedAttr = 0;
65+
let cppNamespace = "::mlir::mesh";
66+
}
67+
5268
//===----------------------------------------------------------------------===//
5369
// Mesh Attribute
5470
//===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
122138
$partial_axes^ `]`)? `>`
123139
}];
124140

141+
let builders = [
142+
AttrBuilder<(ins "SymbolRefAttr":$cluster,
143+
"ArrayRef<SmallVector<int32_t>>":$split_axes,
144+
"ArrayRef<int32_t>": $partial_axes,
145+
"mesh::Partial": $partial_type), [{
146+
SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
147+
split_axes, [&](ArrayRef<int32_t> array) {
148+
return DenseI32ArrayAttr::get($_ctxt, array);
149+
});
150+
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
151+
partial_type);
152+
}]>,
153+
AttrBuilder<(ins "SymbolRefAttr":$cluster,
154+
"ArrayRef<SmallVector<int32_t>>":$split_axes), [{
155+
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
156+
}]>
157+
];
158+
125159
let genVerifyDecl = 1;
126160
}
127161

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,22 @@
2424
#define GET_OP_CLASSES
2525
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
2626

27+
namespace mlir {
28+
namespace mesh {
29+
30+
bool isReductionLoop(IteratorType iType);
31+
32+
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
33+
34+
/// Drops empty sub-arrays from the end of `array`, modifying it in place.
/// Popping stops at the first non-empty sub-array counted from the back, so an
/// empty sub-array that is followed by a non-empty one is kept. If every
/// sub-array is empty, `array` ends up empty.
template <typename T>
35+
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
36+
while (!array.empty() && array.back().empty())
37+
array.pop_back();
38+
}
39+
40+
Partial getPartialTypeFromReduction(IteratorType iType);
41+
42+
} // namespace mesh
43+
} // namespace mlir
44+
2745
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
2+
mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
3+
mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
4+
add_public_tablegen_target(MLIRShardingInterfaceIncGen)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
10+
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
11+
12+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
namespace mlir {
16+
17+
class Operation;
18+
19+
namespace mesh {
20+
21+
using ShardingArray = SmallVector<SmallVector<int32_t>>;
22+
using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
23+
24+
// Describes how the loops of an operation are sharded across mesh axes.
struct ShardingOption {
25+
// An array of int arrays. The sub-array at the i-th position lists the mesh
26+
// axes the i-th loop will be sharded on.
27+
ShardingArray shardingArray = {};
28+
// Symbol reference kept alongside the sharding array; presumably it names the
// mesh cluster whose axes `shardingArray` refers to -- TODO confirm.
SymbolRefAttr cluster = nullptr;
29+
// `empty` being true indicates that no sharding information can be inferred
30+
// at present. Note that it is different from the case where an operation is
31+
// not sharded.
32+
bool empty = false;
33+
// Default state: empty sharding array, null cluster, `empty` == false.
ShardingOption() = default;
34+
// Takes ownership of `shardingArray` (moved in) and stores `cluster`; `empty`
// keeps its default value of false.
ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
35+
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
36+
};
37+
38+
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
39+
// result and includes the 'annotate_for_users' information.
40+
FailureOr<std::pair<bool, MeshShardingAttr>>
41+
getMeshShardingAttr(OpResult result);
42+
43+
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
44+
// operand and includes the 'annotate_for_users' information.
45+
FailureOr<std::pair<bool, MeshShardingAttr>>
46+
getMeshShardingAttr(OpOperand &opOperand);
47+
48+
namespace detail {
49+
50+
FailureOr<ShardingOption>
51+
defaultGetShardingOption(Operation *op,
52+
ArrayRef<MeshShardingAttr> operandShardings,
53+
ArrayRef<MeshShardingAttr> resultShardings);
54+
55+
LogicalResult
56+
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
57+
const ShardingOption &shardingOption);
58+
59+
} // namespace detail
60+
61+
} // namespace mesh
62+
63+
} // namespace mlir
64+
65+
/// Include the ODS generated interface header files.
66+
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
67+
68+
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
10+
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def ShardingInterface : OpInterface<"ShardingInterface"> {
15+
let description = [{
16+
Interface for allowing operations to expose information needed to
17+
shard them.
18+
}];
19+
let cppNamespace = "::mlir::mesh";
20+
21+
let methods = [
22+
InterfaceMethod<
23+
/*desc=*/[{
24+
Returns a list of iterator types that describe the number of loops.
25+
The iterator types determine how the operation traverses its input and
26+
output tensors.
27+
28+
Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
29+
types are parallel, parallel, reduction-sum. This indicates that M and
30+
N are traversed in parallel, while the K dimension is used for
31+
reduction.
32+
33+
Example 2: A softmax op's loop iterator types are parallel and
34+
invalid. The second dimension is considered as invalid because it is
35+
neither parallel nor any kind of reduction.
36+
}],
37+
/*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
38+
/*methodName=*/"getLoopIteratorTypes",
39+
/*args=*/(ins),
40+
/*methodBody=*/"",
41+
/*defaultImplementation=*/"return {};"
42+
>,
43+
InterfaceMethod<
44+
/*desc=*/[{
45+
Return the indexing maps attribute within the current operation.
46+
Indexing maps determine how indices in the iteration space map to
47+
tensor indices. They are specified using `affine_map` in MLIR, which
48+
provides an affine transformation of indices.
49+
}],
50+
/*retTy=*/"SmallVector<AffineMap>",
51+
/*methodName=*/"getIndexingMaps",
52+
/*args=*/(ins),
53+
/*methodBody=*/"",
54+
/*defaultImplementation=*/"return {};"
55+
>,
56+
InterfaceMethod<
57+
/*desc=*/[{
58+
Given that certain operands or results of the operation may have
59+
sharding annotations, this method leverages this information to deduce
60+
how the operation should be sharded.
61+
}],
62+
/*retTy=*/"FailureOr<ShardingOption>",
63+
/*methodName=*/"getShardingOption",
64+
/*args=*/(ins
65+
"ArrayRef<MeshShardingAttr>": $operandShardings,
66+
"ArrayRef<MeshShardingAttr>": $resultShardings
67+
),
68+
/*methodBody=*/"",
69+
/*defaultImplementation=*/[{
70+
return detail::defaultGetShardingOption(
71+
$_op.getOperation(), operandShardings, resultShardings);
72+
}]
73+
>,
74+
InterfaceMethod<
75+
/*desc=*/[{
76+
Based on a given ShardingOption, this method adds `mesh.shard`
77+
operations for the operands and results that previously lacked
78+
sharding annotations.
79+
}],
80+
/*retTy=*/"LogicalResult",
81+
/*methodName=*/"addShardingAnnotations",
82+
/*args=*/(ins
83+
"OpBuilder &":$b,
84+
"const ShardingOption &":$shardingOption
85+
),
86+
/*methodBody=*/"",
87+
/*defaultImplementation=*/[{
88+
return detail::defaultAddShardingAnnotations(
89+
$_op.getOperation(), b, shardingOption);
90+
}]
91+
>
92+
];
93+
94+
let extraClassDeclaration = [{
95+
LogicalResult verifyShardingInterfaceImpl();
96+
97+
void printLoopTypesAndIndexingMaps(raw_ostream &os);
98+
}];
99+
}
100+
101+
102+
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
3+
add_public_tablegen_target(MLIRMeshPassIncGen)
4+
add_dependencies(mlir-headers MLIRMeshPassIncGen)
5+
6+
add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
10+
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
14+
namespace mlir {
15+
16+
namespace func {
17+
// Forward declaration only; this header does not directly include the func
// dialect.
class FuncOp;
18+
} // namespace func
19+
20+
namespace mesh {
21+
22+
//===----------------------------------------------------------------------===//
23+
// Passes
24+
//===----------------------------------------------------------------------===//
25+
26+
#define GEN_PASS_DECL
27+
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
28+
29+
//===----------------------------------------------------------------------===//
30+
// Registration
31+
//===----------------------------------------------------------------------===//
32+
33+
#define GEN_PASS_REGISTRATION
34+
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
35+
36+
} // namespace mesh
37+
} // namespace mlir
38+
39+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
10+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
11+
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
12+
13+
include "mlir/Pass/PassBase.td"
14+
15+
//===----------------------------------------------------------------------===//
16+
// ShardingPropagation
17+
//===----------------------------------------------------------------------===//
18+
19+
def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
20+
let summary = "sharding propagation";
21+
let description = [{
22+
Propagates sharding information throughout the graph. After this pass, each
23+
of the operations' operands and results is annotated with a `mesh.shard`
24+
operation, and the operations themselves are added with sharding option
25+
attributes.
26+
}];
27+
let dependentDialects = [
28+
"mesh::MeshDialect"
29+
];
30+
}
31+
32+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// The guard name mirrors the header's include path
// (mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h).
#ifndef MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
10+
#define MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
11+
12+
namespace mlir {
13+
14+
// Forward declaration; only a reference to the registry is needed here.
class DialectRegistry;
15+
16+
namespace tosa {
17+
18+
/// Registers the TOSA sharding-interface external models with `registry`.
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
19+
20+
} // namespace tosa
21+
} // namespace mlir
22+
23+
#endif // MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_

mlir/include/mlir/IR/AffineMap.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ class AffineMap {
104104
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
105105
MLIRContext *context);
106106

107+
/// Returns an affine map with `numDims` input dimensions and results
108+
/// specified by `targets`.
109+
///
110+
/// Examples:
111+
/// * getMultiDimMapWithTargets(3, [0, 2, 1])
112+
/// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
113+
/// * getMultiDimMapWithTargets(3, [2, 1])
114+
/// -> affine_map<(d0, d1, d2) -> (d2, d1)>
115+
static AffineMap getMultiDimMapWithTargets(unsigned numDims,
116+
ArrayRef<unsigned> targets,
117+
MLIRContext *context);
118+
107119
/// Returns a vector of AffineMaps; each with as many results as
108120
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
109121
/// symbols as the largest symbol in `exprs`.

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
8080
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
8181
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
82+
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
8283
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
8384
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
8485
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
171172
tensor::registerSubsetOpInterfaceExternalModels(registry);
172173
tensor::registerTilingInterfaceExternalModels(registry);
173174
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
175+
tosa::registerShardingInterfaceExternalModels(registry);
174176
vector::registerBufferizableOpInterfaceExternalModels(registry);
175177
vector::registerSubsetOpInterfaceExternalModels(registry);
176178
NVVM::registerNVVMTargetInterfaceExternalModels(registry);

0 commit comments

Comments
 (0)