Skip to content

[mlir][mesh] Refactoring code organization, tests and docs #79606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/docs/Dialects/Mesh.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ These are the axes specified by `mesh_axes` attribute.
For example, on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
an in-group device with `(i, j)`. Then for each group with index `g` on the
second axis, the in-group device would be `(i, g, j)`.

### Purity
Collectives that involve the whole device group to perform a single operation
are pure. The exceptions are `send` and `recv`.
Expand All @@ -72,4 +71,4 @@ passes like dead code and common sub-expression elimination.

## Attributes

[include "Dialects/MeshAttributes.md"]
[include "Dialects/MeshAttrs.md"]
25 changes: 17 additions & 8 deletions mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
add_mlir_dialect(MeshOps mesh)
add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh)
add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh)

set(LLVM_TARGET_DEFINITIONS MeshOps.td)
mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh)
mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh)

set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)

set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)

set(LLVM_TARGET_DEFINITIONS MeshOps.td)
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)

add_public_tablegen_target(MLIRMeshIncGen)
add_dependencies(mlir-headers MLIRMeshIncGen)
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
#mesh.shard<@mesh0, [[]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
#mesh.shard<@mesh0, [[0]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
#mesh.shard<@mesh0, [[0]], partial = max[1]>

// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//===- MeshDialect.h - Mesh Dialect -----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H

#include "mlir/IR/Dialect.h"

#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc"

#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
37 changes: 32 additions & 5 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <algorithm>

namespace mlir {
namespace mesh {
Expand All @@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
} // namespace mesh
} // namespace mlir

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"

#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
Expand All @@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {

Partial getPartialTypeFromReduction(IteratorType iType);

// Resolve `meshSymbol` to the mesh op it names, using the symbol tables nearest
// to `op`. The lookup goes through `symbolTableCollection`.
// The returned op may be null when the symbol cannot be resolved, so callers
// are expected to test it (e.g. `if (!mesh)`) before use.
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                            SymbolTableCollection &symbolTableCollection) {
  mesh::MeshOp resolvedMesh =
      symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(op,
                                                                  meshSymbol);
  return resolvedMesh;
}

// Convenience overload for ops that follow the standard attribute
// nomenclature: the mesh symbol is taken from `op.getMeshAttr()` and resolved
// relative to `op.getOperation()`.
template <typename Op>
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
  Operation *anchor = op.getOperation();
  auto meshSymbol = op.getMeshAttr();
  return getMesh(anchor, meshSymbol, symbolTableCollection);
}

// Compute how many processes take part in each group induced by `meshAxes`:
// the product of the `meshShape` entries selected by those axes.
// If any selected entry is dynamic the result is ShapedType::kDynamic.
// With no axes at all the result is 1.
// `meshShape` must support `std::begin(meshShape) + axis`; no bounds check is
// performed on the axes here.
template <typename MeshAxesRange, typename MeshShapeRange>
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
                                   MeshShapeRange &&meshShape) {
  int64_t groupSize = 1;
  for (MeshAxis meshAxis : meshAxes) {
    const auto extent = *(std::begin(meshShape) + meshAxis);
    // A single dynamic dimension makes the whole group size unknown.
    if (ShapedType::isDynamic(extent))
      return ShapedType::kDynamic;
    groupSize *= extent;
  }
  return groupSize;
}

} // namespace mesh
} // namespace mlir

Expand Down
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
// A device mesh with 2 axes, the number of devices along both axes
// is unknown
mesh.mesh @mesh3(shape = ?x?)

// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
```
}];
let arguments = (ins
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ add_mlir_dialect_library(MLIRMeshDialect
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh

DEPENDS
MLIRMeshOpsAttrIncGen
MLIRMeshOpsEnumsIncGen
MLIRMeshOpsIncGen
MLIRMeshIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
Expand Down
70 changes: 18 additions & 52 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -41,25 +43,7 @@
using namespace mlir;
using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"

// Sort [begin, end) and move one representative of every run of equal
// elements to the front. Returns the iterator one past that unique prefix;
// elements beyond it are left in an unspecified state by std::unique.
template <typename It>
static It canonicalizeSetAsArray(It begin, It end) {
  llvm::sort(begin, end);
  It uniqueEnd = std::unique(begin, end);
  return uniqueEnd;
}

// Range convenience wrapper: forwards the range's begin/end iterators to the
// iterator-pair overload and returns whatever that overload returns.
template <typename R>
static auto canonicalizeSetAsArray(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return canonicalizeSetAsArray(first, last);
}

// Canonicalize `vec` in place as a set: its elements end up sorted and
// de-duplicated, and the vector is shrunk to the number of distinct elements.
// Returns a reference to `vec` itself.
template <typename T>
static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
  auto uniqueEnd = canonicalizeSetAsArray(vec);
  // Everything past `uniqueEnd` is leftover from de-duplication; drop it.
  const auto uniqueCount = uniqueEnd - vec.begin();
  vec.resize(uniqueCount);
  return vec;
}
#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"

namespace {

Expand Down Expand Up @@ -101,7 +85,7 @@ void MeshDialect::initialize() {
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
>();
}

Expand All @@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//

static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::MeshOp mesh =
symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
if (!mesh) {
return op->emitError() << "Undefined required mesh symbol \""
<< meshSymbol.getValue() << "\".";
Expand Down Expand Up @@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
if (rank <= 0)
return emitOpError("rank of mesh is expected to be a positive integer");

if (getShape().size() > size_t(rank))
return emitOpError(
"rank of shape is not expected to be larger than rank of mesh");

for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
return emitOpError("dimension size of a mesh is expected to be "
Expand All @@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {

LogicalResult
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
Expand Down Expand Up @@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {

LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
Expand Down Expand Up @@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,

LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
Expand Down Expand Up @@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
template <typename Op>
static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
auto mesh =
::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
Expand All @@ -450,21 +431,6 @@ static auto product(R &&range) {
return product(adl_begin(range), adl_end(range));
}

// Compute the number of devices in each group induced by `meshAxes` on a mesh
// of shape `meshShape`, i.e. the product of the mesh dimension sizes along
// those axes.
// Returns ShapedType::kDynamic as soon as one selected dimension is dynamic,
// and 1 when `meshAxes` is empty.
// Every axis must be a valid index into `meshShape`; this is asserted.
static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
  int64_t res = 1;

  for (MeshAxis axis : meshAxes) {
    // Validate the axis before it is used as an index; previously the check
    // ran only after `meshShape[axis]` had already been read.
    assert(static_cast<size_t>(axis) < meshShape.size() &&
           "mesh axis is out of range for the mesh shape");
    const int64_t axisSize = meshShape[axis];
    if (ShapedType::isDynamic(axisSize)) {
      return ShapedType::kDynamic;
    }
    res *= axisSize;
  }

  return res;
}

static LogicalResult verifyDimensionCompatibility(Location loc,
int64_t expectedDimSize,
int64_t resultDimSize,
Expand Down Expand Up @@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
auto deviceGroupSize =
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
Expand Down Expand Up @@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
}

auto deviceGroupSize =
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
DimensionSize expectedResultConcatDimSize =
Expand Down Expand Up @@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
}

auto deviceGroupSize =
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
auto operandScatterDimSize =
DimensionSize(operandType.getDimSize(scatterAxis));
if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
Expand Down Expand Up @@ -846,6 +812,6 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"

#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Mesh/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Pass/Pass.h"
Expand Down
20 changes: 4 additions & 16 deletions mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
Expand Down Expand Up @@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
return dim * shardCount;
}

// Number of shards produced by splitting along `splitAxes`: the product of the
// `meshShape` entries indexed by those axes. Yields ShapedType::kDynamic if
// any of the indexed mesh dimensions is dynamic, and 1 for no split axes.
template <typename MeshShape, typename SplitAxes>
int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
  int64_t count = 1;
  for (auto axis : splitAxes) {
    const int64_t extent = meshShape[axis];
    // One dynamic mesh dimension makes the shard count unknown.
    if (ShapedType::isDynamic(extent))
      return ShapedType::kDynamic;
    count *= extent;
  }
  return count;
}

// Compute the shape for the tensor on each device in the mesh.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
Expand All @@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
outShape[tensorAxis] =
shardDimension(inShape[tensorAxis],
shardCount(meshShape, innerSplitAxes.asArrayRef()));
outShape[tensorAxis] = shardDimension(
inShape[tensorAxis],
collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
}
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
Expand Down
Loading