Skip to content

Commit e26395e

Browse files
committed
[mlir] Implement Mesh's ShardingInterface for Linalg ops
Allows linalg structured operations to be handled during spmdization and sharding propagation. There is only support for projected permutation indexing maps.
1 parent 552da24 commit e26395e

File tree

16 files changed

+714
-16
lines changed

16 files changed

+714
-16
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a common entry point for registering all external
10+
// interface implementations to the linalg dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
15+
#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
16+
17+
namespace mlir {
18+
class DialectRegistry;
19+
20+
namespace linalg {
21+
void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
22+
} // namespace linalg
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
10+
#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace linalg {
16+
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace linalg
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
4646
I32EnumAttrCase<"Sum", 1, "sum">,
4747
I32EnumAttrCase<"Max", 2, "max">,
4848
I32EnumAttrCase<"Min", 3, "min">,
49+
I32EnumAttrCase<"Product", 4, "product">,
50+
// Arithmetic mean.
51+
I32EnumAttrCase<"Average", 5, "average">,
52+
I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
53+
I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
54+
I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
4955
I32EnumAttrCase<"Generic", 100, "generic">
5056
]> {
5157
let genSpecializedAttr = 0;

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
353353
attr-dict `:` type($input) `->` type($result)
354354
}];
355355
let hasCanonicalizer = 1;
356+
let builders = [
357+
OpBuilder<(ins "Value":$input, "StringRef":$mesh,
358+
"ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
359+
];
356360
}
357361

358362
def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@ class SymbolTableCollection;
2222

2323
namespace mesh {
2424

25+
// Retrieve the mesh axes corresponding to each operation loop iterator based
26+
// on the provided shardings for the op's operands and results.
27+
// Assumes that the indexingMaps are projected permutations.
28+
ShardingArray getMeshAxisAssignmentForLoopIterators(
29+
ArrayRef<MeshShardingAttr> operandShardings,
30+
ArrayRef<MeshShardingAttr> resultShardings,
31+
ArrayRef<utils::IteratorType> loopIteratorTypes,
32+
ArrayRef<AffineMap> indexingMaps);
33+
34+
bool isAtLeastOneReductionIteratorSharded(
35+
ArrayRef<utils::IteratorType> loopIteratorTypes,
36+
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37+
38+
// Get the set of mesh axes that correspond to reduction loop iterators.
39+
SmallVector<MeshAxis> getReductionMeshAxes(
40+
ArrayRef<utils::IteratorType> loopIteratorTypes,
41+
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42+
2543
// Inserts a clone of the operation that has all ranked tensor
2644
// arguments/results sharded.
2745
void spmdizeTriviallyShardableOperation(

mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/BuiltinTypes.h"
1414
#include "mlir/IR/Value.h"
1515
#include "mlir/Support/LLVM.h"
16+
#include "llvm/ADT/ArrayRef.h"
1617

1718
namespace mlir {
1819
class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType>
3738
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
3839
ImplicitLocOpBuilder &builder);
3940

41+
// Get process linear index along the given mesh axes.
42+
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
43+
ArrayRef<MeshAxis> meshAxes,
44+
ImplicitLocOpBuilder &builder);
45+
4046
} // namespace mesh
4147
} // namespace mlir
4248

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@
4343
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
4444
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
4545
#include "mlir/Dialect/Linalg/IR/Linalg.h"
46-
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
47-
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
48-
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
49-
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
46+
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
5047
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
5148
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
5249
#include "mlir/Dialect/MPI/IR/MPI.h"
@@ -155,10 +152,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
155152
cf::registerBufferizableOpInterfaceExternalModels(registry);
156153
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
157154
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
158-
linalg::registerBufferizableOpInterfaceExternalModels(registry);
159-
linalg::registerSubsetOpInterfaceExternalModels(registry);
160-
linalg::registerTilingInterfaceExternalModels(registry);
161-
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
155+
linalg::registerAllDialectInterfaceImplementations(registry);
162156
memref::registerAllocationOpInterfaceExternalModels(registry);
163157
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
164158
memref::registerValueBoundsOpInterfaceExternalModels(registry);
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- AllInterfaces.cpp - --------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
10+
11+
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
12+
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
13+
#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
14+
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
15+
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
16+
17+
void mlir::linalg::registerAllDialectInterfaceImplementations(
18+
DialectRegistry &registry) {
19+
registerBufferizableOpInterfaceExternalModels(registry);
20+
registerMeshShardingInterfaceExternalModels(registry);
21+
registerSubsetOpInterfaceExternalModels(registry);
22+
registerTilingInterfaceExternalModels(registry);
23+
registerValueBoundsOpInterfaceExternalModels(registry);
24+
}

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRLinalgTransforms
2+
AllInterfaces.cpp
23
BubbleUpExtractSlice.cpp
34
BufferizableOpInterfaceImpl.cpp
45
Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2122
InlineScalarOperands.cpp
2223
Interchange.cpp
2324
Loops.cpp
25+
MeshShardingInterfaceImpl.cpp
2426
NamedOpConversions.cpp
2527
Padding.cpp
2628
Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
6163
MLIRIR
6264
MLIRMemRefDialect
6365
MLIRMemRefTransforms
66+
MLIRMeshDialect
67+
MLIRMeshTransforms
6468
MLIRLinalgDialect
6569
MLIRLinalgUtils
6670
MLIRSCFDialect
6771
MLIRSCFTransforms
6872
MLIRSCFUtils
6973
MLIRPass
74+
MLIRShardingInterface
7075
MLIRSubsetOpInterface
7176
MLIRSparseTensorDialect
7277
MLIRTensorDialect

0 commit comments

Comments
 (0)