Skip to content

Commit eee2ea9

Browse files
committed
Add sharding interface promise to the Linalg dialect
1 parent e26395e commit eee2ea9

File tree

4 files changed

+14
-0
lines changed

4 files changed

+14
-0
lines changed

mlir/include/mlir/IR/Dialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ class Dialect {
216216
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
217217
}
218218

219+
template <typename InterfaceT, typename... ConcreteT>
220+
void declarePromisedInterfaces() {
221+
(declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
222+
}
223+
219224
/// Checks if the given interface, which is attempting to be used, is a
220225
/// promised interface of this dialect that has yet to be implemented. If so,
221226
/// emits a fatal error. `interfaceName` is an optional string that contains a

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
2525
MLIRInferTypeOpInterface
2626
MLIRIR
2727
MLIRParser
28+
MLIRShardingInterface
2829
MLIRSideEffectInterfaces
2930
MLIRSparseTensorDialect
3031
MLIRSCFDialect

mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1717
#include "mlir/Dialect/Math/IR/Math.h"
1818
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
1920
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/Dialect.h"
@@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
118119
>(namedStructuredOpRegionBuilders);
119120

120121
addInterfaces<LinalgInlinerInterface>();
122+
123+
declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
124+
declarePromisedInterfaces<mesh::ShardingInterface,
125+
#define GET_OP_LIST
126+
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
127+
>();
121128
}
122129

123130
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10840,6 +10840,7 @@ cc_library(
1084010840
":MemRefDialect",
1084110841
":Parser",
1084210842
":SCFDialect",
10843+
":MeshShardingInterface",
1084310844
":SideEffectInterfaces",
1084410845
":SparseTensorDialect",
1084510846
":Support",

0 commit comments

Comments
 (0)