Skip to content

Commit 121bdc3

Browse files
committed
[mlir][mesh] Refactoring code organization, tests and docs
* Split out `MeshDialect.h` form `MeshOps.h` that defines the dialect class. Reduces include clutter if you care only about the dialect and not the ops. * Expose functions `getMesh` and `collectiveProcessGroupSize`. There functions are useful for outside users of the dialect. * Remove unused code. * Remove examples and tests of mesh.shard attribute in tensor encoding. Per the decision that Spmdization would be performed on sharding annotations and there will be no tensors with sharding specified in the type. For more info see this RFC comment: https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
1 parent 1f930cf commit 121bdc3

File tree

15 files changed

+108
-139
lines changed

15 files changed

+108
-139
lines changed

mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ add_mlir_dialect(MeshOps mesh)
22
add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
33

44
set(LLVM_TARGET_DEFINITIONS MeshBase.td)
5-
mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
6-
mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
7-
add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
5+
mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
6+
mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
7+
add_public_tablegen_target(MLIRMeshAttrIncGen)
88

99
set(LLVM_TARGET_DEFINITIONS MeshBase.td)
10-
mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
11-
mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
12-
add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
10+
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
11+
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
12+
add_public_tablegen_target(MLIRMeshEnumsIncGen)

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
123123
// The tensor is fully replicated on @mesh0.
124124
// Currently, there must be at least one sub-array present in axes, even
125125
// if it's empty. Otherwise, a parsing error will occur.
126-
tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
126+
#mesh.shard<@mesh0, [[]]>
127127

128128
// The tensor is sharded on the first dimension along axis 0 of @mesh0
129-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
129+
#mesh.shard<@mesh0, [[0]]>
130130

131131
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
132132
// it is also a partial_sum along mesh axis 1.
133-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
133+
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>
134134

135135
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
136136
// it is also a partial_max along mesh axis 1.
137-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
137+
#mesh.shard<@mesh0, [[0]], partial = max[1]>
138138

139139
// Could be used in the attribute of mesh.shard op
140140
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
10+
#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
11+
12+
#include "mlir/IR/Dialect.h"
13+
14+
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
15+
16+
#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "mlir/IR/SymbolTable.h"
1616
#include "mlir/Interfaces/InferTypeOpInterface.h"
1717
#include "mlir/Interfaces/SideEffectInterfaces.h"
18-
#include <algorithm>
1918

2019
namespace mlir {
2120
namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
2625
} // namespace mesh
2726
} // namespace mlir
2827

29-
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
30-
31-
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
28+
#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
3229

3330
#define GET_ATTRDEF_CLASSES
34-
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
31+
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
3532

3633
#define GET_OP_CLASSES
3734
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
5148

5249
Partial getPartialTypeFromReduction(IteratorType iType);
5350

51+
inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
52+
SymbolTableCollection &symbolTableCollection) {
53+
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
54+
op, meshSymbol);
55+
}
56+
57+
// Get the corresponding mesh op using the standard attribute nomenclature.
58+
template <typename Op>
59+
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
60+
return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
61+
}
62+
63+
// Get the number of processes that participate in each group
64+
// induced by `meshAxes`.
65+
template <typename MeshAxesRange, typename MeshShapeRange>
66+
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
67+
MeshShapeRange &&meshShape) {
68+
int64_t res = 1;
69+
70+
for (MeshAxis axis : meshAxes) {
71+
auto axisSize = *(std::begin(meshShape) + axis);
72+
if (ShapedType::isDynamic(axisSize)) {
73+
return ShapedType::kDynamic;
74+
}
75+
res *= axisSize;
76+
}
77+
78+
return res;
79+
}
80+
5481
} // namespace mesh
5582
} // namespace mlir
5683

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
6161
// A device mesh with 2 axes, the number of devices along both axes
6262
// is unknown
6363
mesh.mesh @mesh3(shape = ?x?)
64-
65-
// Used in the mesh sharding attribute to extend the standard tensor to
66-
// distributed
67-
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
6864
```
6965
}];
7066
let arguments = (ins

mlir/include/mlir/InitAllDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
5555
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
5656
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
57-
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
57+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
5858
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
5959
#include "mlir/Dialect/OpenACC/OpenACC.h"
6060
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"

mlir/lib/Dialect/Mesh/IR/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ add_mlir_dialect_library(MLIRMeshDialect
55
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
66

77
DEPENDS
8-
MLIRMeshOpsAttrIncGen
9-
MLIRMeshOpsEnumsIncGen
8+
MLIRMeshAttrIncGen
9+
MLIRMeshEnumsIncGen
1010
MLIRMeshOpsIncGen
1111

1212
LINK_LIBS PUBLIC

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 14 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
10+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
11+
1012
#include "mlir/Dialect/Arith/IR/Arith.h"
1113
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1214
#include "mlir/IR/Attributes.h"
@@ -43,24 +45,6 @@ using namespace mlir::mesh;
4345

4446
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
4547

46-
template <typename It>
47-
static It canonicalizeSetAsArray(It begin, It end) {
48-
llvm::sort(begin, end);
49-
return std::unique(begin, end);
50-
}
51-
52-
template <typename R>
53-
static auto canonicalizeSetAsArray(R &&range) {
54-
return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
55-
}
56-
57-
template <typename T>
58-
static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
59-
auto newEnd = canonicalizeSetAsArray(vec);
60-
vec.resize(newEnd - vec.begin());
61-
return vec;
62-
}
63-
6448
namespace {
6549

6650
struct DimensionSize {
@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
11498
// Mesh utilities
11599
//===----------------------------------------------------------------------===//
116100

117-
static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
118-
SymbolTableCollection &symbolTable) {
119-
mesh::MeshOp mesh =
120-
symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
101+
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
102+
FlatSymbolRefAttr meshSymbol,
103+
SymbolTableCollection &symbolTable) {
104+
mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
121105
if (!mesh) {
122106
return op->emitError() << "Undefined required mesh symbol \""
123107
<< meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
201185
if (rank <= 0)
202186
return emitOpError("rank of mesh is expected to be a positive integer");
203187

204-
if (getShape().size() > size_t(rank))
205-
return emitOpError(
206-
"rank of shape is not expected to be larger than rank of mesh");
207-
208188
for (int64_t dimSize : getShape()) {
209189
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
210190
return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
220200

221201
LogicalResult
222202
MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
223-
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
203+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
224204
if (failed(mesh)) {
225205
return failure();
226206
}
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
322302

323303
LogicalResult
324304
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
325-
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
305+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
326306
if (failed(mesh)) {
327307
return failure();
328308
}
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
360340

361341
LogicalResult
362342
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
363-
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
343+
auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
364344
if (failed(mesh)) {
365345
return failure();
366346
}
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
428408
template <typename Op>
429409
static FailureOr<MeshOp>
430410
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
431-
auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
411+
auto mesh =
412+
::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
432413
if (failed(mesh)) {
433414
return failure();
434415
}
@@ -450,21 +431,6 @@ static auto product(R &&range) {
450431
return product(adl_begin(range), adl_end(range));
451432
}
452433

453-
static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
454-
ArrayRef<int64_t> meshShape) {
455-
int64_t res = 1;
456-
457-
for (MeshAxis axis : meshAxes) {
458-
if (ShapedType::isDynamic(meshShape[axis])) {
459-
return ShapedType::kDynamic;
460-
}
461-
assert(size_t(axis) < meshShape.size());
462-
res *= meshShape[axis];
463-
}
464-
465-
return res;
466-
}
467-
468434
static LogicalResult verifyDimensionCompatibility(Location loc,
469435
int64_t expectedDimSize,
470436
int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
495461
ShapedType operandType = operand.getType().cast<ShapedType>();
496462
ShapedType resultType = result.getType().cast<ShapedType>();
497463
auto deviceGroupSize =
498-
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
464+
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
499465
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
500466
auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
501467
auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
529495
}
530496

531497
auto deviceGroupSize =
532-
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
498+
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
533499
auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
534500
auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
535501
DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
570536
}
571537

572538
auto deviceGroupSize =
573-
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
539+
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
574540
auto operandScatterDimSize =
575541
DimensionSize(operandType.getDimSize(scatterAxis));
576542
if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
1010

1111
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1213
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1314
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
1415
#include "mlir/Pass/Pass.h"

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
1111
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1212
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1314
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1415
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1516
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
5455
return dim * shardCount;
5556
}
5657

57-
template <typename MeshShape, typename SplitAxes>
58-
int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
59-
int64_t res = 1;
60-
for (auto splitAxis : splitAxes) {
61-
int64_t meshDimSize = meshShape[splitAxis];
62-
if (ShapedType::isDynamic(meshDimSize)) {
63-
return ShapedType::kDynamic;
64-
}
65-
res *= meshDimSize;
66-
}
67-
return res;
68-
}
69-
7058
// Compute the shape for the tensor on each device in the mesh.
7159
// Example:
7260
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
7866
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
7967
llvm::adl_begin(outShape));
8068
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
81-
outShape[tensorAxis] =
82-
shardDimension(inShape[tensorAxis],
83-
shardCount(meshShape, innerSplitAxes.asArrayRef()));
69+
outShape[tensorAxis] = shardDimension(
70+
inShape[tensorAxis],
71+
collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
8472
}
8573
}
8674

mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
1010
#include "mlir/Dialect/Affine/IR/AffineOps.h"
11+
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1112
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314
#include "mlir/IR/DialectRegistry.h"

0 commit comments

Comments
 (0)