Skip to content

Commit 31fc0a1

Browse files
authored
[mlir][mesh] Refactoring code organization, tests and docs (#79606)
* Split out `MeshDialect.h` form `MeshOps.h` that defines the dialect class. Reduces include clutter if you care only about the dialect and not the ops. * Expose functions `getMesh` and `collectiveProcessGroupSize`. There functions are useful for outside users of the dialect. * Remove unused code. * Remove examples and tests of mesh.shard attribute in tensor encoding. Per the decision that Spmdization would be performed on sharding annotations and there will be no tensors with sharding specified in the type. For more info see this RFC comment: https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
1 parent 6e6aa44 commit 31fc0a1

File tree

16 files changed

+123
-148
lines changed

16 files changed

+123
-148
lines changed

mlir/docs/Dialects/Mesh.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ These are the axes specified by `mesh_axes` attribute.
4646
For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
4747
an in-group device with `(i, j)`. Then for each group with index `g` on the
4848
second axis, the in-group device would be `(i, g, j)`.
49-
5049
### Purity
5150
Collectives that involve the whole device group to perform a single operation
5251
are pure. The exceptions are `send` and `recv`.
@@ -72,4 +71,4 @@ passes like dead code and common sub-expression elimination.
7271

7372
## Attributes
7473

75-
[include "Dialects/MeshAttributes.md"]
74+
[include "Dialects/MeshAttrs.md"]
Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
1-
add_mlir_dialect(MeshOps mesh)
2-
add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
1+
add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh)
2+
add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh)
3+
4+
set(LLVM_TARGET_DEFINITIONS MeshOps.td)
5+
mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh)
6+
mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh)
37

48
set(LLVM_TARGET_DEFINITIONS MeshBase.td)
5-
mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
6-
mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
7-
add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
9+
mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
10+
mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
811

912
set(LLVM_TARGET_DEFINITIONS MeshBase.td)
10-
mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
11-
mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
12-
add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
13+
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
14+
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
15+
16+
set(LLVM_TARGET_DEFINITIONS MeshOps.td)
17+
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
18+
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
19+
20+
add_public_tablegen_target(MLIRMeshIncGen)
21+
add_dependencies(mlir-headers MLIRMeshIncGen)

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
123123
// The tensor is fully replicated on @mesh0.
124124
// Currently, there must be at least one sub-array present in axes, even
125125
// if it's empty. Otherwise, a parsing error will occur.
126-
tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
126+
#mesh.shard<@mesh0, [[]]>
127127

128128
// The tensor is sharded on the first dimension along axis 0 of @mesh0
129-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
129+
#mesh.shard<@mesh0, [[0]]>
130130

131131
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
132132
// it is also a partial_sum along mesh axis 1.
133-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
133+
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>
134134

135135
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
136136
// it is also a partial_max along mesh axis 1.
137-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
137+
#mesh.shard<@mesh0, [[0]], partial = max[1]>
138138

139139
// Could be used in the attribute of mesh.shard op
140140
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
10+
#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
11+
12+
#include "mlir/IR/Dialect.h"
13+
14+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc"
15+
16+
#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "mlir/IR/SymbolTable.h"
1616
#include "mlir/Interfaces/InferTypeOpInterface.h"
1717
#include "mlir/Interfaces/SideEffectInterfaces.h"
18-
#include <algorithm>
1918

2019
namespace mlir {
2120
namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
2625
} // namespace mesh
2726
} // namespace mlir
2827

29-
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
30-
31-
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
28+
#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
3229

3330
#define GET_ATTRDEF_CLASSES
34-
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
31+
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
3532

3633
#define GET_OP_CLASSES
3734
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
5148

5249
Partial getPartialTypeFromReduction(IteratorType iType);
5350

51+
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
52+
SymbolTableCollection &symbolTableCollection) {
53+
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
54+
op, meshSymbol);
55+
}
56+
57+
// Get the corresponding mesh op using the standard attribute nomenclature.
58+
template <typename Op>
59+
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
60+
return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
61+
}
62+
63+
// Get the number of processes that participate in each group
64+
// induced by `meshAxes`.
65+
template <typename MeshAxesRange, typename MeshShapeRange>
66+
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
67+
MeshShapeRange &&meshShape) {
68+
int64_t res = 1;
69+
70+
for (MeshAxis axis : meshAxes) {
71+
auto axisSize = *(std::begin(meshShape) + axis);
72+
if (ShapedType::isDynamic(axisSize)) {
73+
return ShapedType::kDynamic;
74+
}
75+
res *= axisSize;
76+
}
77+
78+
return res;
79+
}
80+
5481
} // namespace mesh
5582
} // namespace mlir
5683

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
6161
// A device mesh with 2 axes, the number of devices along both axes
6262
// is unknown
6363
mesh.mesh @mesh3(shape = ?x?)
64-
65-
// Used in the mesh sharding attribute to extend the standard tensor to
66-
// distributed
67-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
6864
```
6965
}];
7066
let arguments = (ins

mlir/include/mlir/InitAllDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
5656
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
5757
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
58-
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
58+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
5959
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
6060
#include "mlir/Dialect/OpenACC/OpenACC.h"
6161
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"

mlir/lib/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ add_mlir_dialect_library(MLIRMeshDialect
55
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
66

77
DEPENDS
8-
MLIRMeshOpsAttrIncGen
9-
MLIRMeshOpsEnumsIncGen
10-
MLIRMeshOpsIncGen
8+
MLIRMeshIncGen
119

1210
LINK_LIBS PUBLIC
1311
MLIRArithDialect

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
10+
1011
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1113
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1214
#include "mlir/IR/Attributes.h"
1315
#include "mlir/IR/BuiltinAttributes.h"
@@ -41,25 +43,7 @@
4143
using namespace mlir;
4244
using namespace mlir::mesh;
4345

44-
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
45-
46-
template <typename It>
47-
static It canonicalizeSetAsArray(It begin, It end) {
48-
llvm::sort(begin, end);
49-
return std::unique(begin, end);
50-
}
51-
52-
template <typename R>
53-
static auto canonicalizeSetAsArray(R &&range) {
54-
return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
55-
}
56-
57-
template <typename T>
58-
static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
59-
auto newEnd = canonicalizeSetAsArray(vec);
60-
vec.resize(newEnd - vec.begin());
61-
return vec;
62-
}
46+
#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
6347

6448
namespace {
6549

@@ -101,7 +85,7 @@ void MeshDialect::initialize() {
10185
>();
10286
addAttributes<
10387
#define GET_ATTRDEF_LIST
104-
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
88+
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
10589
>();
10690
}
10791

@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
11498
// Mesh utilities
11599
//===----------------------------------------------------------------------===//
116100

117-
static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
118-
SymbolTableCollection &symbolTable) {
119-
mesh::MeshOp mesh =
120-
symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
101+
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
102+
FlatSymbolRefAttr meshSymbol,
103+
SymbolTableCollection &symbolTable) {
104+
mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
121105
if (!mesh) {
122106
return op->emitError() << "Undefined required mesh symbol \""
123107
<< meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
201185
if (rank <= 0)
202186
return emitOpError("rank of mesh is expected to be a positive integer");
203187

204-
if (getShape().size() > size_t(rank))
205-
return emitOpError(
206-
"rank of shape is not expected to be larger than rank of mesh");
207-
208188
for (int64_t dimSize : getShape()) {
209189
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
210190
return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
220200

221201
LogicalResult
222202
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
223-
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
203+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
224204
if (failed(mesh)) {
225205
return failure();
226206
}
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
322302

323303
LogicalResult
324304
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
325-
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
305+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
326306
if (failed(mesh)) {
327307
return failure();
328308
}
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
360340

361341
LogicalResult
362342
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
363-
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
343+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
364344
if (failed(mesh)) {
365345
return failure();
366346
}
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
428408
template <typename Op>
429409
static FailureOr<MeshOp>
430410
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
431-
auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
411+
auto mesh =
412+
::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
432413
if (failed(mesh)) {
433414
return failure();
434415
}
@@ -450,21 +431,6 @@ static auto product(R &&range) {
450431
return product(adl_begin(range), adl_end(range));
451432
}
452433

453-
static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
454-
ArrayRef<int64_t> meshShape) {
455-
int64_t res = 1;
456-
457-
for (MeshAxis axis : meshAxes) {
458-
if (ShapedType::isDynamic(meshShape[axis])) {
459-
return ShapedType::kDynamic;
460-
}
461-
assert(size_t(axis) < meshShape.size());
462-
res *= meshShape[axis];
463-
}
464-
465-
return res;
466-
}
467-
468434
static LogicalResult verifyDimensionCompatibility(Location loc,
469435
int64_t expectedDimSize,
470436
int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
495461
ShapedType operandType = operand.getType().cast<ShapedType>();
496462
ShapedType resultType = result.getType().cast<ShapedType>();
497463
auto deviceGroupSize =
498-
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
464+
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
499465
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
500466
auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
501467
auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
529495
}
530496

531497
auto deviceGroupSize =
532-
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
498+
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
533499
auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
534500
auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
535501
DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
570536
}
571537

572538
auto deviceGroupSize =
573-
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
539+
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
574540
auto operandScatterDimSize =
575541
DimensionSize(operandType.getDimSize(scatterAxis));
576542
if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
@@ -846,6 +812,6 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
846812
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
847813

848814
#define GET_ATTRDEF_CLASSES
849-
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
815+
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
850816

851-
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"
817+
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
1010

1111
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1213
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1314
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
1415
#include "mlir/Pass/Pass.h"

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
1111
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1212
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1314
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1415
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1516
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
5455
return dim * shardCount;
5556
}
5657

57-
template <typename MeshShape, typename SplitAxes>
58-
int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
59-
int64_t res = 1;
60-
for (auto splitAxis : splitAxes) {
61-
int64_t meshDimSize = meshShape[splitAxis];
62-
if (ShapedType::isDynamic(meshDimSize)) {
63-
return ShapedType::kDynamic;
64-
}
65-
res *= meshDimSize;
66-
}
67-
return res;
68-
}
69-
7058
// Compute the shape for the tensor on each device in the mesh.
7159
// Example:
7260
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
7866
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
7967
llvm::adl_begin(outShape));
8068
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
81-
outShape[tensorAxis] =
82-
shardDimension(inShape[tensorAxis],
83-
shardCount(meshShape, innerSplitAxes.asArrayRef()));
69+
outShape[tensorAxis] = shardDimension(
70+
inShape[tensorAxis],
71+
collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
8472
}
8573
}
8674

mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
1010
#include "mlir/Dialect/Affine/IR/AffineOps.h"
11+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1112
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314
#include "mlir/IR/DialectRegistry.h"

0 commit comments

Comments
 (0)