Skip to content

[MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extension lib #104913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//===- MeshShardingExtensions.h - -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a common entry point for registering all extensions to the
// Tensor dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H

namespace mlir {
class DialectRegistry;

namespace tensor {
/// Register all extensions of the Tensor dialect. This should generally only be
/// used by tools, or other use cases that really do want *all* extensions of
/// the dialect. All other cases should prefer to instead register the specific
/// extensions they intend to take advantage of.
void registerAllExtensions(DialectRegistry &registry);
} // namespace tensor

} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
#define MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_

namespace mlir {

class DialectRegistry;

namespace tensor {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace tensor
} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
2 changes: 0 additions & 2 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
Expand Down Expand Up @@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerShardingInterfaceExternalModels(registry);
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
Expand All @@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertComplexToLLVMInterface(registry);
cf::registerConvertControlFlowToLLVMInterface(registry);
func::registerAllExtensions(registry);
tensor::registerAllExtensions(registry);
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
TensorShardingInterfaceImpl.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"

using namespace mlir;

void mlir::tensor::registerAllExtensions(DialectRegistry &registry) {
registerShardingInterfaceExternalModels(registry);
}
Comment on lines +14 to +16
Copy link
Contributor

@sogartar sogartar Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR drives anther one that should add the rest of the extensions of the Tensor dialect here.

26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
MeshShardingExtensions.cpp
)

add_mlir_extension_library(MLIRTensorMeshShardingExtensions
MeshShardingExtensions.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions

LINK_LIBS PUBLIC
MLIRTensorDialect
MLIRIR
MLIRShardingInterface
)

add_mlir_extension_library(MLIRTensorAllExtensions
AllExtensions.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions

LINK_LIBS PUBLIC
MLIRTensorMeshShardingExtensions
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/tools/mlir-lsp-server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ set(LIBS
MLIRLspServerLib
MLIRParser
MLIRPass
MLIRTensorAllExtensions
MLIRTransforms
MLIRTransformUtils
MLIRSupport
Expand Down
Loading