Skip to content

[mlir] Implement Mesh's ShardingInterface for Linalg ops #82284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a common entry point for registering all external
// interface implementations to the linalg dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H

namespace mlir {
class DialectRegistry;

namespace linalg {
void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
} // namespace linalg

} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace linalg {
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
I32EnumAttrCase<"Product", 4, "product">,
// Arithmetic mean.
I32EnumAttrCase<"Average", 5, "average">,
I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
I32EnumAttrCase<"Generic", 100, "generic">
]> {
let genSpecializedAttr = 0;
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
let builders = [
OpBuilder<(ins "Value":$input, "StringRef":$mesh,
"ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
];
}

def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ class SymbolTableCollection;

namespace mesh {

// Retrieve the mesh axes corresponding to each operation loop iterator based
// on the provided shardings for the op's operands and results.
// Assumes that the indexingMaps are projected permutations.
ShardingArray getMeshAxisAssignmentForLoopIterators(
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps);

bool isAtLeastOneReductionIteratorSharded(
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);

// Get the set of mesh axes that correspond to reduction loop iterators.
SmallVector<MeshAxis> getReductionMeshAxes(
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);

// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
void spmdizeTriviallyShardableOperation(
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {
class RewritePatternSet;
Expand All @@ -37,6 +38,11 @@ TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder);

// Get process linear index along the given mesh axes.
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder);

} // namespace mesh
} // namespace mlir

Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ class Dialect {
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
}

// Declare the same interface for multiple types.
// Example:
// declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
template <typename InterfaceT, typename... ConcreteT>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an important header. Can we add some documentation for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment.

void declarePromisedInterfaces() {
(declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
}

/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error. `interfaceName` is an optional string that contains a
Expand Down
10 changes: 2 additions & 8 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
Expand Down Expand Up @@ -155,10 +152,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferizableOpInterfaceExternalModels(registry);
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerSubsetOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRParser
MLIRShardingInterface
MLIRSideEffectInterfaces
MLIRSparseTensorDialect
MLIRSCFDialect
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
Expand Down Expand Up @@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
>(namedStructuredOpRegionBuilders);

addInterfaces<LinalgInlinerInterface>();

declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
declarePromisedInterfaces<mesh::ShardingInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();
}

LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===- AllInterfaces.cpp - ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"

#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

void mlir::linalg::registerAllDialectInterfaceImplementations(
DialectRegistry &registry) {
registerBufferizableOpInterfaceExternalModels(registry);
registerMeshShardingInterfaceExternalModels(registry);
registerSubsetOpInterfaceExternalModels(registry);
registerTilingInterfaceExternalModels(registry);
registerValueBoundsOpInterfaceExternalModels(registry);
}
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalgTransforms
AllInterfaces.cpp
BubbleUpExtractSlice.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
Expand All @@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
Expand Down Expand Up @@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRIR
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRMeshDialect
MLIRMeshTransforms
MLIRLinalgDialect
MLIRLinalgUtils
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils
MLIRPass
MLIRShardingInterface
MLIRSubsetOpInterface
MLIRSparseTensorDialect
MLIRTensorDialect
Expand Down
Loading