Skip to content

Commit dff2f59

Browse files
authored
[mlir][mesh] Add TableGen deffinitions of more collective ops (#73842)
Add definitions for broadcast, gather, receive, reduce, scatter, send and shift.
1 parent 01e40a8 commit dff2f59

File tree

3 files changed

+409
-0
lines changed

3 files changed

+409
-0
lines changed

mlir/docs/Dialects/Mesh.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@ explanation.
1515
The main addition is that the collectives in this dialect have mesh
1616
semantics.
1717

18+
### Device groups
1819
The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh
1920
axes that partition the devices into disjoint groups.
2021
The collective operation is performed between devices in the same group.
2122
Devices that have the same coordinates outside of axes `mesh_axes` are in the
2223
same group.
24+
A group is described by its multi-index along the axes outside of `mesh_axes`.
2325
For example if we have a device mesh of size `2x3x4x5` and the partition mesh
2426
axes list is `[0, 1]` then devices are partitioned into the groups
2527
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
28+
The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`.
2629
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
2730
Device (1, 0, 2, 4) will be in another group.
2831
Some collective operations like all-to-all and all-gather care about the
@@ -33,6 +36,17 @@ The axes are ordered from outer to inner.
3336
If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
3437
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
3538

39+
### In-group Device
40+
Some operations like `broadcast`, `scatter` and `send` specify devices in each
41+
device-group.
42+
These devices are represented with their multi-index over the mesh axes that
43+
are not constant within a device group.
44+
These are the axes specified by `mesh_axes` attribute.
45+
46+
For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
47+
an in-group device with `(i, j)`. Then for each group with index `g` on the
48+
second axis, the in-group device would be `(i, g, j)`.
49+
3650

3751
## Operations
3852

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
339339
let hasCanonicalizer = 1;
340340
}
341341

342+
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
343+
AllShapesMatch<["input", "result"]>,
344+
AllElementTypesMatch<["input", "result"]>
345+
]> {
346+
let summary = "Broadcast over a device mesh.";
347+
let description = [{
348+
Broadcast the tensor on `root` to all devices in each respective group.
349+
The operation broadcasts along mesh axes `mesh_axes`.
350+
The `root` device specifies the in-group multi-index that is broadcast to
351+
all other devices in the group.
352+
353+
Example:
354+
```
355+
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
356+
357+
%1 = mesh.broadcast %0 on @mesh0
358+
mesh_axes = [0]
359+
root = [0]
360+
: (tensor<2xi8>) -> tensor<2xi8>
361+
```
362+
363+
Input:
364+
```
365+
+-------+-------+ | broadcast
366+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
367+
+-------+-------+ ↓
368+
device (1, 0) -> | | | <- device (1, 1)
369+
+-------+-------+
370+
```
371+
372+
Output:
373+
```
374+
+-------+-------+
375+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
376+
+-------+-------+
377+
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
378+
+-------+-------+
379+
```
380+
}];
381+
let arguments = !con(commonArgs, (ins
382+
AnyRankedTensor:$input,
383+
DenseI64ArrayAttr:$root,
384+
Variadic<Index>:$root_dynamic
385+
));
386+
let results = (outs
387+
AnyRankedTensor:$result
388+
);
389+
let assemblyFormat = [{
390+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
391+
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
392+
attr-dict `:` functional-type(operands, results)
393+
}];
394+
}
395+
396+
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
397+
AllRanksMatch<["input", "result"]>,
398+
AllElementTypesMatch<["input", "result"]>
399+
]> {
400+
let summary = "Gather over a device mesh.";
401+
let description = [{
402+
Gathers on device `root` along the `gather_axis` tensor axis.
403+
`root` specifies the coordinates of a device along `mesh_axes`.
404+
It uniquely identifies the root device for each device group.
405+
The result tensor on non-root devices is undefined.
406+
Using it will result in undefined behavior.
407+
408+
Example:
409+
```mlir
410+
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
411+
...
412+
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
413+
gather_axis = 1 root = [1]
414+
: (tensor<2x2xi8>) -> tensor<2x4xi8>
415+
```
416+
Input:
417+
```
418+
gather tensor
419+
axis 1
420+
------------>
421+
+-------+-------+
422+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
423+
| 3 4 | 7 8 |
424+
+-------+-------+
425+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
426+
| 11 12 | 15 16 |
427+
+-------+-------+
428+
```
429+
Result:
430+
```
431+
+-------------+
432+
| 1 2 5 6 | <- devices (0, 1)
433+
| 3 4 7 8 |
434+
+-------------+
435+
| 9 10 13 14 | <- devices (1, 1)
436+
| 11 12 15 16 |
437+
+-------------+
438+
```
439+
Devices `(0, 0)` and `(1, 0)` have undefined result.
440+
}];
441+
let arguments = !con(commonArgs, (ins
442+
AnyNon0RankedTensor:$input,
443+
IndexAttr:$gather_axis,
444+
DenseI64ArrayAttr:$root,
445+
Variadic<Index>:$root_dynamic
446+
));
447+
let results = (outs
448+
AnyNon0RankedTensor:$result
449+
);
450+
let assemblyFormat = [{
451+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
452+
`gather_axis` `=` $gather_axis
453+
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
454+
attr-dict `:` functional-type(operands, results)
455+
}];
456+
}
457+
458+
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
459+
AllShapesMatch<["input", "result"]>,
460+
AllElementTypesMatch<["input", "result"]>
461+
]> {
462+
let summary = "Send over a device mesh.";
463+
let description = [{
464+
Receive from a device within a device group.
465+
}];
466+
let arguments = !con(commonArgs, (ins
467+
AnyNon0RankedTensor:$input,
468+
OptionalAttr<DenseI64ArrayAttr>:$source,
469+
Variadic<Index>:$source_dynamic
470+
));
471+
let results = (outs
472+
AnyRankedTensor:$result
473+
);
474+
let assemblyFormat = [{
475+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
476+
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
477+
attr-dict `:` functional-type(operands, results)
478+
}];
479+
}
480+
481+
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
482+
AllShapesMatch<["input", "result"]>
483+
]> {
484+
let summary = "Reduce over a device mesh.";
485+
let description = [{
486+
Reduces on device `root` within each device group.
487+
`root` specifies the coordinates of a device along `mesh_axes`.
488+
It uniquely identifies the root device within its device group.
489+
The accumulation element type is specified by the result type and
490+
it does not need to match the input element type.
491+
The input element is converted to the result element type before
492+
performing the reduction.
493+
494+
Attributes:
495+
`reduction`: Indicates the reduction method.
496+
497+
Example:
498+
```
499+
%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
500+
reduction = <max> root = [2, 3]
501+
: (tensor<3x4xf32>) -> tensor<3x4xf64>
502+
```
503+
}];
504+
let arguments = !con(commonArgs, (ins
505+
AnyRankedTensor:$input,
506+
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
507+
DenseI64ArrayAttr:$root,
508+
Variadic<Index>:$root_dynamic
509+
));
510+
let results = (outs
511+
AnyRankedTensor:$result
512+
);
513+
let assemblyFormat = [{
514+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
515+
(`reduction` `=` $reduction^)?
516+
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
517+
attr-dict `:` functional-type(operands, results)
518+
}];
519+
}
520+
342521
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
343522
SameOperandsAndResultRank]> {
344523
let summary = "Reduce-scatter over a device mesh.";
@@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
400579
let hasCanonicalizer = 1;
401580
}
402581

582+
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
583+
AllRanksMatch<["input", "result"]>,
584+
AllElementTypesMatch<["input", "result"]>
585+
]> {
586+
let summary = "Scatter over a device mesh.";
587+
let description = [{
588+
For each device group split the input tensor on the `root` device along
589+
axis `scatter_axis` and scatter the parts across the group devices.
590+
591+
Example:
592+
```
593+
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
594+
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
595+
scatter_axis = 0
596+
root = [1]
597+
: (tensor<2x2xi8>) -> tensor<1x2xi8>
598+
```
599+
600+
Input:
601+
```
602+
device
603+
(0, 1)
604+
605+
+-------+-------+ | scatter tensor
606+
device (0, 0) -> | | | | axis 0
607+
| | | ↓
608+
+-------+-------+
609+
device (1, 0) -> | 1 2 | 5 6 |
610+
| 3 4 | 7 8 |
611+
+-------+-------+
612+
613+
device
614+
(1, 1)
615+
```
616+
617+
Result:
618+
```
619+
device
620+
(0, 1)
621+
622+
+-------+-------+
623+
device (0, 0) -> | 1 2 | 5 6 |
624+
+-------+-------+
625+
device (1, 0) -> | 3 4 | 7 8 |
626+
+-------+-------+
627+
628+
device
629+
(1, 1)
630+
```
631+
}];
632+
let arguments = !con(commonArgs, (ins
633+
AnyNon0RankedTensor:$input,
634+
IndexAttr:$scatter_axis,
635+
DenseI64ArrayAttr:$root,
636+
Variadic<Index>:$root_dynamic
637+
));
638+
let results = (outs
639+
AnyRankedTensor:$result
640+
);
641+
let assemblyFormat = [{
642+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
643+
`scatter_axis` `=` $scatter_axis
644+
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
645+
attr-dict `:` functional-type(operands, results)
646+
}];
647+
}
648+
649+
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
650+
AllShapesMatch<["input", "result"]>,
651+
AllElementTypesMatch<["input", "result"]>
652+
]> {
653+
let summary = "Send over a device mesh.";
654+
let description = [{
655+
Send from one device to another within a device group.
656+
}];
657+
let arguments = !con(commonArgs, (ins
658+
AnyNon0RankedTensor:$input,
659+
DenseI64ArrayAttr:$destination,
660+
Variadic<Index>:$destination_dynamic
661+
));
662+
let results = (outs
663+
AnyRankedTensor:$result
664+
);
665+
let assemblyFormat = [{
666+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
667+
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
668+
attr-dict `:` functional-type(operands, results)
669+
}];
670+
}
671+
672+
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
673+
SameOperandsAndResultElementType,
674+
SameOperandsAndResultShape
675+
]> {
676+
let summary = "Sift over a device mesh.";
677+
let description = [{
678+
Within each device group shift along mesh axis `shift_axis` by an offset
679+
`offset`.
680+
The result on devices that do not have a corresponding source is undefined.
681+
`shift_axis` must be one of `mesh_axes`.
682+
If the `rotate` attribute is present,
683+
instead of a shift a rotation is done.
684+
685+
Example:
686+
```
687+
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
688+
%1 = mesh.shift on @mesh0 mesh_axes = [1]
689+
shift_axis = 1 offset = 2 rotate
690+
: tensor<2xi8> -> tensor<2xi8>
691+
```
692+
693+
Input:
694+
```
695+
mesh axis 1
696+
----------->
697+
698+
+----+----+----+----+
699+
| 1 | 2 | 3 | 4 |
700+
+----+----+----+----+
701+
| 5 | 6 | 7 | 8 |
702+
+----+----+----+----+
703+
```
704+
705+
Result:
706+
```
707+
+----+----+----+----+
708+
| 3 | 4 | 1 | 2 |
709+
+----+----+----+----+
710+
| 7 | 8 | 5 | 6 |
711+
+----+----+----+----+
712+
```
713+
}];
714+
let arguments = !con(commonArgs, (ins
715+
AnyNon0RankedTensor:$input,
716+
IndexAttr:$shift_axis,
717+
I64Attr:$offset,
718+
UnitAttr:$rotate
719+
));
720+
let results = (outs
721+
AnyRankedTensor:$result
722+
);
723+
let assemblyFormat = [{
724+
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
725+
`shift_axis` `=` $shift_axis
726+
`offset` `=` $offset
727+
(`rotate` $rotate^)?
728+
attr-dict `:` type($input) `->` type($result)
729+
}];
730+
}
731+
403732
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD

0 commit comments

Comments
 (0)