Skip to content

Commit 25d68a6

Browse files
committed
[MLIR][Mesh] Add sharding propagation pass
1 parent e8fe4de commit 25d68a6

27 files changed

+1421
-1
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
add_subdirectory(Interfaces)
12
add_subdirectory(IR)
3+
add_subdirectory(Transforms)

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
4949
let cppNamespace = "::mlir::mesh";
5050
}
5151

52+
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
53+
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
54+
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
55+
// is partial.
56+
def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
57+
I32EnumAttrCase<"Parallel", 1, "parallel">,
58+
I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
59+
I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
60+
I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
61+
I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
62+
I32EnumAttrCase<"Invalid", 100, "invalid">
63+
]> {
64+
let genSpecializedAttr = 0;
65+
let cppNamespace = "::mlir::mesh";
66+
}
67+
5268
//===----------------------------------------------------------------------===//
5369
// Mesh Attribute
5470
//===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
122138
$partial_axes^ `]`)? `>`
123139
}];
124140

141+
let builders = [
142+
AttrBuilder<(ins "SymbolRefAttr":$cluster,
143+
"ArrayRef<SmallVector<int32_t>>":$split_axes,
144+
"ArrayRef<int32_t>": $partial_axes,
145+
"mesh::Partial": $partial_type), [{
146+
SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
147+
llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
148+
return DenseI32ArrayAttr::get($_ctxt, array);
149+
}));
150+
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
151+
partial_type);
152+
}]>,
153+
AttrBuilder<(ins "SymbolRefAttr":$cluster,
154+
"ArrayRef<SmallVector<int32_t>>":$split_axes), [{
155+
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
156+
}]>
157+
];
158+
125159
let genVerifyDecl = 1;
126160
}
127161

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,26 @@
2424
#define GET_OP_CLASSES
2525
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
2626

27+
namespace mlir {
28+
namespace mesh {
29+
30+
bool isReductionLoop(IteratorType iType);
31+
32+
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
33+
34+
/// Removes, in place, every empty sub-array at the back of `array`, stopping
/// at the first (rightmost) non-empty sub-array. Empty sub-arrays that are
/// followed by a non-empty one are kept.
template <typename T>
35+
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
36+
// The index is signed so that an empty `array` starts at -1 and the loop body
// is never entered.
for (int64_t i = array.size() - 1; i >= 0; i--) {
37+
// `array[i]` is always the current last element here: each earlier iteration
// popped exactly one element while `i` decreased by one.
if (array[i].empty())
38+
array.pop_back();
39+
else
40+
break;
41+
}
42+
}
43+
44+
Partial getPartialTypeFromReduction(IteratorType iType);
45+
46+
} // namespace mesh
47+
} // namespace mlir
48+
2749
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
2+
mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
3+
mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
4+
add_public_tablegen_target(MLIRShardingInterfaceIncGen)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
10+
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
11+
12+
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
namespace mlir {
16+
17+
class Operation;
18+
19+
namespace mesh {
20+
21+
using ShardingArray = SmallVector<SmallVector<int32_t>>;
22+
using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
23+
24+
// Describes how the loops of an operation are sharded.
struct ShardingOption {
25+
// An array of int arrays. The sub-array at the i-th position signifies the
26+
// mesh axes the i-th loop will be sharded on.
27+
ShardingArray shardingArray;
28+
// NOTE(review): presumably a symbol reference to the mesh cluster the axes in
// `shardingArray` belong to -- confirm against the users of this struct.
SymbolRefAttr cluster;
29+
// `empty` being true indicates that no sharding information can be inferred
30+
// at present. Note that this is different from an operation not being sharded.
31+
bool empty = false;
32+
ShardingOption() = default;
33+
ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
34+
: shardingArray(shardingArray), cluster(cluster) {}
35+
};
36+
37+
constexpr StringRef getShardingArrayName() { return "sharding_array"; }
38+
39+
constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
40+
41+
namespace detail {
42+
43+
FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
44+
45+
LogicalResult
46+
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
47+
const ShardingOption &shardingOption);
48+
49+
} // namespace detail
50+
51+
} // namespace mesh
52+
53+
} // namespace mlir
54+
55+
/// Include the ODS generated interface header files.
56+
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
57+
58+
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
10+
#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def ShardingInterface : OpInterface<"ShardingInterface"> {
15+
let description = [{
16+
Interface for allowing operations to expose information needed to
17+
shard them.
18+
}];
19+
let cppNamespace = "::mlir::mesh";
20+
21+
let methods = [
22+
InterfaceMethod<
23+
/*desc=*/[{
24+
Returns a list of iterator types that describe the number of loops.
25+
}],
26+
/*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
27+
/*methodName=*/"getLoopIteratorTypes",
28+
/*args=*/(ins),
29+
/*methodBody=*/"",
30+
/*defaultImplementation=*/"return {};"
31+
>,
32+
InterfaceMethod<
33+
/*desc=*/[{
34+
Return the indexing maps attribute within the current operation.
35+
}],
36+
/*retTy=*/"SmallVector<AffineMap>",
37+
/*methodName=*/"getIndexingMaps",
38+
/*args=*/(ins),
39+
/*methodBody=*/"",
40+
/*defaultImplementation=*/"return {};"
41+
>,
42+
InterfaceMethod<
43+
/*desc=*/[{
44+
Given that certain operands or results of the operation may have
45+
sharding annotations, this method leverages this information to deduce
46+
how the operation should be sharded.
47+
}],
48+
/*retTy=*/"FailureOr<ShardingOption>",
49+
/*methodName=*/"getShardingOption",
50+
/*args=*/(ins
51+
"OpBuilder &":$b
52+
),
53+
/*methodBody=*/"",
54+
/*defaultImplementation=*/[{
55+
return detail::defaultGetShardingOption(
56+
$_op.getOperation(), b);
57+
}]
58+
>,
59+
InterfaceMethod<
60+
/*desc=*/[{
61+
Based on a given ShardingOption, this method adds `mesh.shard`
62+
operations for the operands and results that previously lacked
63+
sharding annotations.
64+
}],
65+
/*retTy=*/"LogicalResult",
66+
/*methodName=*/"addShardingAnnotations",
67+
/*args=*/(ins
68+
"OpBuilder &":$b,
69+
"const ShardingOption &":$shardingOption
70+
),
71+
/*methodBody=*/"",
72+
/*defaultImplementation=*/[{
73+
return detail::defaultAddShardingAnnotations(
74+
$_op.getOperation(), b, shardingOption);
75+
}]
76+
>
77+
];
78+
79+
let extraClassDeclaration = [{
80+
LogicalResult verifyShardingInterfaceImpl();
81+
82+
void printLoopTypesAndIndexingMaps(raw_ostream &os);
83+
}];
84+
}
85+
86+
87+
#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
3+
add_public_tablegen_target(MLIRMeshPassIncGen)
4+
add_dependencies(mlir-headers MLIRMeshPassIncGen)
5+
6+
add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
10+
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
14+
namespace mlir {
15+
16+
namespace func {
17+
class FuncOp;
18+
}
19+
20+
namespace mesh {
21+
22+
//===----------------------------------------------------------------------===//
23+
// Passes
24+
//===----------------------------------------------------------------------===//
25+
26+
#define GEN_PASS_DECL
27+
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
28+
29+
std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
30+
31+
//===----------------------------------------------------------------------===//
32+
// Registration
33+
//===----------------------------------------------------------------------===//
34+
35+
#define GEN_PASS_REGISTRATION
36+
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
37+
38+
} // namespace mesh
39+
} // namespace mlir
40+
41+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
10+
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
11+
#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
12+
13+
include "mlir/Pass/PassBase.td"
14+
15+
//===----------------------------------------------------------------------===//
16+
// ShardingPropagation
17+
//===----------------------------------------------------------------------===//
18+
19+
def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
20+
let summary = "sharding propagation";
21+
let description = [{
22+
Propagates sharding information throughout the graph. After this pass, each
23+
of the operations' operands and results is annotated with a `mesh.shard`
24+
operation, and the operations themselves are added with sharding option
25+
attributes.
26+
}];
27+
let constructor = "mlir::mesh::createShardingPropagationPass()";
28+
let dependentDialects = [
29+
"mesh::MeshDialect"
30+
];
31+
}
32+
33+
#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
10+
#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace tosa {
17+
18+
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
19+
20+
} // namespace tosa
21+
} // namespace mlir
22+
23+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
245245
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
246246
unsigned dropBack = 0);
247247

248+
/// Helper to return a vector of sub-vectors of int32_t.
249+
SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
250+
248251
/// Compute linear index from provided strides and indices, assuming strided
249252
/// layout.
250253
/// Returns AffineExpr and list of values to apply to it, e.g.:

mlir/include/mlir/IR/AffineMap.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ class AffineMap {
101101
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
102102
MLIRContext *context);
103103

104+
/// Returns an affine map with `numDims` input dimensions and results
105+
/// specified by `targets`.
106+
///
107+
/// Examples:
108+
/// * getMultiDimMapWithTargets(3, [0, 2, 1])
109+
/// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
110+
/// * getMultiDimMapWithTargets(3, [2, 1])
111+
/// -> affine_map<(d0, d1, d2) -> (d2, d1)>
112+
static AffineMap getMultiDimMapWithTargets(unsigned numDims,
113+
ArrayRef<int64_t> targets,
114+
MLIRContext *context);
115+
104116
/// Returns a vector of AffineMaps; each with as many results as
105117
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
106118
/// symbols as the largest symbol in `exprs`.

mlir/include/mlir/IR/Builders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class Builder {
168168
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
169169
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
170170
ArrayAttr getTypeArrayAttr(TypeRange values);
171+
ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
171172

172173
// Affine expressions and affine maps.
173174
AffineExpr getAffineDimExpr(unsigned position);

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
8080
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
8181
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
82+
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
8283
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
8384
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
8485
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
170171
tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
171172
tensor::registerTilingInterfaceExternalModels(registry);
172173
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
174+
tosa::registerShardingInterfaceExternalModels(registry);
173175
vector::registerBufferizableOpInterfaceExternalModels(registry);
174176
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
175177
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);

0 commit comments

Comments
 (0)