Skip to content

Commit 944e031

Browse files
authored
[mlir][mesh] Use tensor shape notation for the shape of a cluster (#73826)
Examle: substitute mesh.cluster @Mesh0(rank = 2, dim_sizes = [0, 4]) with mesh.cluster @Mesh0(rank = 2, dim_sizes = ?x4) Same as tensor/memref shapes. The only difference is for 0-rank shapes. With tensors you would have something like `tensor<f32>`. Here to avoid matching an empty string a 0-rank shape is denoted by `[]`.
1 parent 02379d1 commit 944e031

File tree

11 files changed

+194
-71
lines changed

11 files changed

+194
-71
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,27 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
4040
determine the layout and the addressing space of the computation distributed
4141
across the mesh.
4242

43-
3. `dim_sizes`: This attribute represents the device assignment along the
44-
axes of the cluster. Each integer in the array corresponds to the number of
45-
devices along a specific axis. If an integer value is 0, it implies that the
46-
number of devices along that axis is unknown. This flexibility allows for
47-
dynamic device assignment or configurations where the exact number of
48-
devices might not be determined during compile time.
43+
3. `dim_sizes`: This attribute represents the shape of the device cluster.
44+
It uses the same notation as a tensor shape. Also allowing for dynamic
45+
dimensions.
46+
This flexibility allows for dynamic device assignment or configurations
47+
where the exact number of devices might not be determined during compile
48+
time.
49+
For example `2x?x4`.
4950

5051
Example:
5152
```
5253
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
5354
// The dimension sizes are 4, 8, 12
54-
mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
55+
mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
5556

5657
// A device mesh cluster with 2 axes, the total device number is unknown
5758
// The first dimension size is 4 and the second is unknown
58-
mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
59+
mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
5960

6061
// A device mesh cluster with 2 axes, the total device number is unknown
6162
// The first dimension size is unknown and the second is 4
62-
mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
63+
mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
6364

6465
// A device mesh cluster with 2 axes, the number of devices along both axes
6566
// is unknown
@@ -76,7 +77,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
7677
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
7778
);
7879
let assemblyFormat = [{
79-
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
80+
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
8081
attr-dict
8182
}];
8283
let extraClassDeclaration = [{
@@ -88,7 +89,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
8889
template <typename OutIt>
8990
void canonicalDimSizes(OutIt outIt) {
9091
std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
91-
std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
92+
std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
9293
}
9394
}];
9495
let hasVerifier = 1;
@@ -210,7 +211,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
210211

211212
Example:
212213
```mlir
213-
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
214+
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
214215
...
215216
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
216217
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -295,7 +296,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
295296

296297
Example:
297298
```
298-
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
299+
mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
299300
...
300301
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
301302
split_axis = 0 concat_axis = 0
@@ -527,7 +528,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
527528
across the device group.
528529
Example:
529530
```
530-
mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2])
531+
mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
531532
...
532533
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
533534
reduction = <max> scatter_axis = 0

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ class AsmPrinter {
226226
printArrowTypeList(results);
227227
}
228228

229+
void printDimensionList(ArrayRef<int64_t> shape);
230+
229231
/// Class used to automatically end a cyclic region on destruction.
230232
class CyclicPrintReset {
231233
public:
@@ -1762,6 +1764,17 @@ class OpAsmDialectInterface
17621764
const SetVector<AsmDialectResourceHandle> &referencedResources,
17631765
AsmResourceBuilder &builder) const {}
17641766
};
1767+
1768+
//===--------------------------------------------------------------------===//
1769+
// Custom printers and parsers.
1770+
//===--------------------------------------------------------------------===//
1771+
1772+
// Handles custom<DimensionList>(...) in TableGen.
1773+
void printDimensionList(OpAsmPrinter &printer, Operation *op,
1774+
ArrayRef<int64_t> dimensions);
1775+
ParseResult parseDimensionList(OpAsmParser &parser,
1776+
DenseI64ArrayAttr &dimensions);
1777+
17651778
} // namespace mlir
17661779

17671780
//===--------------------------------------------------------------------===//

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
5858
return vec;
5959
}
6060

61-
template <typename DimSize>
62-
static bool isMeshDimensionDynamic(DimSize size) {
63-
return size <= DimSize(0);
64-
}
65-
6661
using MeshAxis = int16_t;
6762

6863
namespace {
@@ -161,9 +156,9 @@ LogicalResult ClusterOp::verify() {
161156
"rank of dim_sizes is not expected to be larger than rank of cluster");
162157

163158
for (int64_t dimSize : dimSizes) {
164-
if (dimSize < 0)
165-
return emitOpError(
166-
"dimension size of a mesh cluster is expected to be non-negative");
159+
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
160+
return emitOpError("dimension size of a mesh cluster is expected to be "
161+
"non-negative or dynamic");
167162
}
168163

169164
return success();
@@ -316,7 +311,7 @@ static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
316311
int64_t res = 1;
317312

318313
for (MeshAxis axis : meshAxes) {
319-
if (isMeshDimensionDynamic(meshShape[axis])) {
314+
if (ShapedType::isDynamic(meshShape[axis])) {
320315
return ShapedType::kDynamic;
321316
}
322317
assert(size_t(axis) < meshShape.size());

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
#include "mlir/IR/AsmState.h"
1717
#include "mlir/IR/Attributes.h"
1818
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinAttributes.h"
1920
#include "mlir/IR/BuiltinDialect.h"
21+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2022
#include "mlir/IR/BuiltinTypes.h"
2123
#include "mlir/IR/Dialect.h"
2224
#include "mlir/IR/DialectImplementation.h"
@@ -27,6 +29,7 @@
2729
#include "mlir/IR/Operation.h"
2830
#include "mlir/IR/Verifier.h"
2931
#include "llvm/ADT/APFloat.h"
32+
#include "llvm/ADT/ArrayRef.h"
3033
#include "llvm/ADT/DenseMap.h"
3134
#include "llvm/ADT/MapVector.h"
3235
#include "llvm/ADT/STLExtras.h"
@@ -44,6 +47,7 @@
4447
#include "llvm/Support/SaveAndRestore.h"
4548
#include "llvm/Support/Threading.h"
4649
#include "llvm/Support/raw_ostream.h"
50+
#include <type_traits>
4751

4852
#include <optional>
4953
#include <tuple>
@@ -425,6 +429,8 @@ class AsmPrinter::Impl {
425429

426430
void popCyclicPrinting();
427431

432+
void printDimensionList(ArrayRef<int64_t> shape);
433+
428434
protected:
429435
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
430436
ArrayRef<StringRef> elidedAttrs = {},
@@ -1860,6 +1866,20 @@ class AsmStateImpl {
18601866
// Allow direct access to the impl fields.
18611867
friend AsmState;
18621868
};
1869+
1870+
template <typename Range>
1871+
void printDimensionList(raw_ostream &stream, Range &&shape) {
1872+
llvm::interleave(
1873+
shape, stream,
1874+
[&stream](const auto &dimSize) {
1875+
if (ShapedType::isDynamic(dimSize))
1876+
stream << "?";
1877+
else
1878+
stream << dimSize;
1879+
},
1880+
"x");
1881+
}
1882+
18631883
} // namespace detail
18641884
} // namespace mlir
18651885

@@ -2576,13 +2596,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
25762596
})
25772597
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
25782598
os << "tensor<";
2579-
for (int64_t dim : tensorTy.getShape()) {
2580-
if (ShapedType::isDynamic(dim))
2581-
os << '?';
2582-
else
2583-
os << dim;
2599+
printDimensionList(tensorTy.getShape());
2600+
if (!tensorTy.getShape().empty())
25842601
os << 'x';
2585-
}
25862602
printType(tensorTy.getElementType());
25872603
// Only print the encoding attribute value if set.
25882604
if (tensorTy.getEncoding()) {
@@ -2598,13 +2614,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
25982614
})
25992615
.Case<MemRefType>([&](MemRefType memrefTy) {
26002616
os << "memref<";
2601-
for (int64_t dim : memrefTy.getShape()) {
2602-
if (ShapedType::isDynamic(dim))
2603-
os << '?';
2604-
else
2605-
os << dim;
2617+
printDimensionList(memrefTy.getShape());
2618+
if (!memrefTy.getShape().empty())
26062619
os << 'x';
2607-
}
26082620
printType(memrefTy.getElementType());
26092621
MemRefLayoutAttrInterface layout = memrefTy.getLayout();
26102622
if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
@@ -2735,6 +2747,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
27352747

27362748
void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
27372749

2750+
void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
2751+
detail::printDimensionList(os, shape);
2752+
}
2753+
27382754
//===--------------------------------------------------------------------===//
27392755
// AsmPrinter
27402756
//===--------------------------------------------------------------------===//
@@ -2800,6 +2816,10 @@ void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
28002816
impl->printResourceHandle(resource);
28012817
}
28022818

2819+
void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
2820+
detail::printDimensionList(getStream(), shape);
2821+
}
2822+
28032823
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
28042824
return impl->pushCyclicPrinting(opaquePointer);
28052825
}
@@ -3911,3 +3931,47 @@ void Block::printAsOperand(raw_ostream &os, AsmState &state) {
39113931
OperationPrinter printer(os, state.getImpl());
39123932
printer.printBlockName(this);
39133933
}
3934+
3935+
//===--------------------------------------------------------------------===//
3936+
// Custom printers
3937+
//===--------------------------------------------------------------------===//
3938+
namespace mlir {
3939+
3940+
void printDimensionList(OpAsmPrinter &printer, Operation *op,
3941+
ArrayRef<int64_t> dimensions) {
3942+
if (dimensions.empty())
3943+
printer << "[";
3944+
printer.printDimensionList(dimensions);
3945+
if (dimensions.empty())
3946+
printer << "]";
3947+
}
3948+
3949+
ParseResult parseDimensionList(OpAsmParser &parser,
3950+
DenseI64ArrayAttr &dimensions) {
3951+
// Empty list case denoted by "[]".
3952+
if (succeeded(parser.parseOptionalLSquare())) {
3953+
if (failed(parser.parseRSquare())) {
3954+
return parser.emitError(parser.getCurrentLocation())
3955+
<< "Failed parsing dimension list.";
3956+
}
3957+
dimensions =
3958+
DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>());
3959+
return success();
3960+
}
3961+
3962+
// Non-empty list case.
3963+
SmallVector<int64_t> shapeArr;
3964+
if (failed(parser.parseDimensionList(shapeArr, true, false))) {
3965+
return parser.emitError(parser.getCurrentLocation())
3966+
<< "Failed parsing dimension list.";
3967+
}
3968+
if (shapeArr.empty()) {
3969+
return parser.emitError(parser.getCurrentLocation())
3970+
<< "Failed parsing dimension list. Did you mean an empty list? It "
3971+
"must be denoted by \"[]\".";
3972+
}
3973+
dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
3974+
return success();
3975+
}
3976+
3977+
} // namespace mlir

mlir/test/Dialect/Mesh/canonicalization.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt --canonicalize %s | FileCheck %s
22

3-
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
3+
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
44

55
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
66
func.func @all_reduce_empty_mesh_axes(

0 commit comments

Comments
 (0)