Skip to content

Commit fb582b6

Browse files
authored
[mlir] Implement Mesh's ShardingInterface for Linalg ops (#82284)
Allows linalg structured operations to be handled during spmdization and sharding propagation. There is only support for projected permutation indexing maps.
1 parent 9d3bf9b commit fb582b6

File tree

19 files changed

+754
-19
lines changed

19 files changed

+754
-19
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a common entry point for registering all external
10+
// interface implementations to the linalg dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
15+
#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
16+
17+
namespace mlir {
18+
class DialectRegistry;
19+
20+
namespace linalg {
21+
void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
22+
} // namespace linalg
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
10+
#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace linalg {
16+
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace linalg
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
4646
I32EnumAttrCase<"Sum", 1, "sum">,
4747
I32EnumAttrCase<"Max", 2, "max">,
4848
I32EnumAttrCase<"Min", 3, "min">,
49+
I32EnumAttrCase<"Product", 4, "product">,
50+
// Arithmetic mean.
51+
I32EnumAttrCase<"Average", 5, "average">,
52+
I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
53+
I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
54+
I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
4955
I32EnumAttrCase<"Generic", 100, "generic">
5056
]> {
5157
let genSpecializedAttr = 0;

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
353353
attr-dict `:` type($input) `->` type($result)
354354
}];
355355
let hasCanonicalizer = 1;
356+
let builders = [
357+
OpBuilder<(ins "Value":$input, "StringRef":$mesh,
358+
"ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
359+
];
356360
}
357361

358362
def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@ class SymbolTableCollection;
2222

2323
namespace mesh {
2424

25+
// Retrieve the mesh axes corresponding to each operation loop iterator based
26+
// on the provided shardings for the op's operands and results.
27+
// Assumes that the indexingMaps are projected permutations.
28+
ShardingArray getMeshAxisAssignmentForLoopIterators(
29+
ArrayRef<MeshShardingAttr> operandShardings,
30+
ArrayRef<MeshShardingAttr> resultShardings,
31+
ArrayRef<utils::IteratorType> loopIteratorTypes,
32+
ArrayRef<AffineMap> indexingMaps);
33+
34+
bool isAtLeastOneReductionIteratorSharded(
35+
ArrayRef<utils::IteratorType> loopIteratorTypes,
36+
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37+
38+
// Get the set of mesh axes that correspond to reduction loop iterators.
39+
SmallVector<MeshAxis> getReductionMeshAxes(
40+
ArrayRef<utils::IteratorType> loopIteratorTypes,
41+
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42+
2543
// Inserts a clone of the operation that has all ranked tensor
2644
// arguments/results sharded.
2745
void spmdizeTriviallyShardableOperation(

mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/BuiltinTypes.h"
1414
#include "mlir/IR/Value.h"
1515
#include "mlir/Support/LLVM.h"
16+
#include "llvm/ADT/ArrayRef.h"
1617

1718
namespace mlir {
1819
class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType>
3738
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
3839
ImplicitLocOpBuilder &builder);
3940

41+
// Get process linear index along the given mesh axes.
42+
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
43+
ArrayRef<MeshAxis> meshAxes,
44+
ImplicitLocOpBuilder &builder);
45+
4046
} // namespace mesh
4147
} // namespace mlir
4248

mlir/include/mlir/IR/Dialect.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ class Dialect {
216216
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
217217
}
218218

219+
// Declare the same interface for multiple types.
220+
// Example:
221+
// declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
222+
template <typename InterfaceT, typename... ConcreteT>
223+
void declarePromisedInterfaces() {
224+
(declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
225+
}
226+
219227
/// Checks if the given interface, which is attempting to be used, is a
220228
/// promised interface of this dialect that has yet to be implemented. If so,
221229
/// emits a fatal error. `interfaceName` is an optional string that contains a

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@
4343
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
4444
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
4545
#include "mlir/Dialect/Linalg/IR/Linalg.h"
46-
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
47-
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
48-
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
49-
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
46+
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
5047
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
5148
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
5249
#include "mlir/Dialect/MPI/IR/MPI.h"
@@ -157,10 +154,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
157154
cf::registerBufferizableOpInterfaceExternalModels(registry);
158155
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
159156
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
160-
linalg::registerBufferizableOpInterfaceExternalModels(registry);
161-
linalg::registerSubsetOpInterfaceExternalModels(registry);
162-
linalg::registerTilingInterfaceExternalModels(registry);
163-
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
157+
linalg::registerAllDialectInterfaceImplementations(registry);
164158
memref::registerAllocationOpInterfaceExternalModels(registry);
165159
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
166160
memref::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
2525
MLIRInferTypeOpInterface
2626
MLIRIR
2727
MLIRParser
28+
MLIRShardingInterface
2829
MLIRSideEffectInterfaces
2930
MLIRSparseTensorDialect
3031
MLIRSCFDialect

mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1717
#include "mlir/Dialect/Math/IR/Math.h"
1818
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
1920
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/Dialect.h"
@@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
118119
>(namedStructuredOpRegionBuilders);
119120

120121
addInterfaces<LinalgInlinerInterface>();
122+
123+
declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
124+
declarePromisedInterfaces<mesh::ShardingInterface,
125+
#define GET_OP_LIST
126+
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
127+
>();
121128
}
122129

123130
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- AllInterfaces.cpp - ------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
10+
11+
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
12+
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
13+
#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
14+
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
15+
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
16+
17+
void mlir::linalg::registerAllDialectInterfaceImplementations(
18+
DialectRegistry &registry) {
19+
registerBufferizableOpInterfaceExternalModels(registry);
20+
registerMeshShardingInterfaceExternalModels(registry);
21+
registerSubsetOpInterfaceExternalModels(registry);
22+
registerTilingInterfaceExternalModels(registry);
23+
registerValueBoundsOpInterfaceExternalModels(registry);
24+
}

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRLinalgTransforms
2+
AllInterfaces.cpp
23
BubbleUpExtractSlice.cpp
34
BufferizableOpInterfaceImpl.cpp
45
Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2122
InlineScalarOperands.cpp
2223
Interchange.cpp
2324
Loops.cpp
25+
MeshShardingInterfaceImpl.cpp
2426
NamedOpConversions.cpp
2527
Padding.cpp
2628
Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
6163
MLIRIR
6264
MLIRMemRefDialect
6365
MLIRMemRefTransforms
66+
MLIRMeshDialect
67+
MLIRMeshTransforms
6468
MLIRLinalgDialect
6569
MLIRLinalgUtils
6670
MLIRSCFDialect
6771
MLIRSCFTransforms
6872
MLIRSCFUtils
6973
MLIRPass
74+
MLIRShardingInterface
7075
MLIRSubsetOpInterface
7176
MLIRSparseTensorDialect
7277
MLIRTensorDialect

0 commit comments

Comments
 (0)