Skip to content

Commit 681ae09

Browse files
authored
[MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extension lib (#104913)
Follow-up to #102598 : as discussed, move tensor sharding implementation into separate tensor extension lib. @sogartar @yaochengji, could you take a look at this PR?
1 parent 9d36428 commit 681ae09

File tree

12 files changed

+101
-5
lines changed

12 files changed

+101
-5
lines changed

mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
1+
//===- MeshShardingExtensions.h - -----------------------------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a common entry point for registering all extensions to the
10+
// Tensor dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
15+
#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
16+
17+
namespace mlir {
18+
class DialectRegistry;
19+
20+
namespace tensor {
21+
/// Register all extensions of the Tensor dialect. This should generally only be
22+
/// used by tools, or other use cases that really do want *all* extensions of
23+
/// the dialect. All other cases should prefer to instead register the specific
24+
/// extensions they intend to take advantage of.
25+
void registerAllExtensions(DialectRegistry &registry);
26+
} // namespace tensor
27+
28+
} // namespace mlir
29+
30+
#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
10+
#define MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace tensor {
17+
18+
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
19+
20+
} // namespace tensor
21+
} // namespace mlir
22+
23+
#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_

mlir/include/mlir/InitAllDialects.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
5959
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
6060
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
61-
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
6261
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
6362
#include "mlir/Dialect/OpenACC/OpenACC.h"
6463
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
182181
tensor::registerBufferizableOpInterfaceExternalModels(registry);
183182
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
184183
tensor::registerInferTypeOpInterfaceExternalModels(registry);
185-
tensor::registerShardingInterfaceExternalModels(registry);
186184
tensor::registerSubsetOpInterfaceExternalModels(registry);
187185
tensor::registerTilingInterfaceExternalModels(registry);
188186
tensor::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
3535
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
3636
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
37+
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
3738
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
3839
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
3940
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
@@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
6061
registerConvertComplexToLLVMInterface(registry);
6162
cf::registerConvertControlFlowToLLVMInterface(registry);
6263
func::registerAllExtensions(registry);
64+
tensor::registerAllExtensions(registry);
6365
registerConvertFuncToLLVMInterface(registry);
6466
index::registerConvertIndexToLLVMInterface(registry);
6567
registerConvertMathToLLVMInterface(registry);

mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_library(MLIRShardingInterface
22
ShardingInterface.cpp
3-
TensorShardingInterfaceImpl.cpp
43

54
ADDITIONAL_HEADER_DIRS
65
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh

mlir/lib/Dialect/Tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(Extensions)
12
add_subdirectory(IR)
23
add_subdirectory(Transforms)
34
add_subdirectory(TransformOps)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
10+
#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
11+
12+
using namespace mlir;
13+
14+
void mlir::tensor::registerAllExtensions(DialectRegistry &registry) {
15+
registerShardingInterfaceExternalModels(registry);
16+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
set(LLVM_OPTIONAL_SOURCES
2+
AllExtensions.cpp
3+
MeshShardingExtensions.cpp
4+
)
5+
6+
add_mlir_extension_library(MLIRTensorMeshShardingExtensions
7+
MeshShardingExtensions.cpp
8+
9+
ADDITIONAL_HEADER_DIRS
10+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
11+
12+
LINK_LIBS PUBLIC
13+
MLIRTensorDialect
14+
MLIRIR
15+
MLIRShardingInterface
16+
)
17+
18+
add_mlir_extension_library(MLIRTensorAllExtensions
19+
AllExtensions.cpp
20+
21+
ADDITIONAL_HEADER_DIRS
22+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
23+
24+
LINK_LIBS PUBLIC
25+
MLIRTensorMeshShardingExtensions
26+
)

mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp renamed to mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
109
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
1110
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
11+
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
1212
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1313
#include "mlir/IR/DialectRegistry.h"
1414
#include "llvm/Support/Debug.h"

mlir/tools/mlir-lsp-server/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ set(LIBS
4747
MLIRLspServerLib
4848
MLIRParser
4949
MLIRPass
50+
MLIRTensorAllExtensions
5051
MLIRTransforms
5152
MLIRTransformUtils
5253
MLIRSupport

0 commit comments

Comments
 (0)