-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] Add TableGen definitions of more collective ops #73842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ | |
let hasCanonicalizer = 1; | ||
} | ||
|
||
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [ | ||
AllShapesMatch<["input", "result"]>, | ||
AllElementTypesMatch<["input", "result"]> | ||
]> { | ||
let summary = "Broadcast over a device mesh."; | ||
let description = [{ | ||
Broadcast the tensor on `root` to all devices in each respective group. | ||
The operation broadcasts along mesh axes `mesh_axes`. | ||
The `root` device specifies the in-group multi-index that is broadcast to | ||
all other devices in the group. | ||
|
||
Example: | ||
``` | ||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) | ||
|
||
%1 = mesh.broadcast %0 on @mesh0 | ||
mesh_axes = [0] | ||
root = [0] | ||
: (tensor<2xi8>) -> tensor<2xi8> | ||
``` | ||
|
||
Input: | ||
``` | ||
+-------+-------+ | broadcast | ||
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0 | ||
+-------+-------+ ↓ | ||
device (1, 0) -> | | | <- device (1, 1) | ||
+-------+-------+ | ||
``` | ||
|
||
Output: | ||
``` | ||
+-------+-------+ | ||
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | ||
+-------+-------+ | ||
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1) | ||
+-------+-------+ | ||
``` | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyRankedTensor:$input, | ||
DenseI64ArrayAttr:$root, | ||
Variadic<Index>:$root_dynamic | ||
)); | ||
let results = (outs | ||
AnyRankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root) | ||
attr-dict `:` functional-type(operands, results) | ||
}]; | ||
} | ||
|
||
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ | ||
AllRanksMatch<["input", "result"]>, | ||
AllElementTypesMatch<["input", "result"]> | ||
]> { | ||
let summary = "Gather over a device mesh."; | ||
let description = [{ | ||
Gathers on device `root` along the `gather_axis` tensor axis. | ||
`root` specifies the coordinates of a device along `mesh_axes`. | ||
It uniquely identifies the root device for each device group. | ||
The result tensor on non-root devices is undefined. | ||
Using it will result in undefined behavior. | ||
|
||
Example: | ||
```mlir | ||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) | ||
... | ||
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1] | ||
gather_axis = 1 root = [1] | ||
: (tensor<2x2xi8>) -> tensor<2x4xi8> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the output still There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we are doing SPMD there seems to be no other choice. We can say that the results on the other devices is undefined. I don't really like to introduce undefined behavior, maybe we could say that it should be 0-filled, but then we force the runtime to make good on that promise and there should really be no programs that touch the result other than the |
||
``` | ||
Input: | ||
``` | ||
gather tensor | ||
axis 1 | ||
------------> | ||
+-------+-------+ | ||
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1) | ||
| 3 4 | 7 8 | | ||
+-------+-------+ | ||
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1) | ||
| 11 12 | 15 16 | | ||
+-------+-------+ | ||
``` | ||
Result: | ||
``` | ||
+-------------+ | ||
| 1 2 5 6 | <- devices (0, 1) | ||
| 3 4 7 8 | | ||
+-------------+ | ||
| 9 10 13 14 | <- devices (1, 1) | ||
| 11 12 15 16 | | ||
+-------------+ | ||
``` | ||
Devices `(0, 0)` and `(1, 0)` have undefined result. | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyNon0RankedTensor:$input, | ||
IndexAttr:$gather_axis, | ||
DenseI64ArrayAttr:$root, | ||
Variadic<Index>:$root_dynamic | ||
)); | ||
let results = (outs | ||
AnyNon0RankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
`gather_axis` `=` $gather_axis | ||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root) | ||
attr-dict `:` functional-type(operands, results) | ||
}]; | ||
} | ||
|
||
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [ | ||
AllShapesMatch<["input", "result"]>, | ||
AllElementTypesMatch<["input", "result"]> | ||
]> { | ||
let summary = "Receive over a device mesh."; | ||
let description = [{ | ||
Receive from a device within a device group. | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyNon0RankedTensor:$input, | ||
OptionalAttr<DenseI64ArrayAttr>:$source, | ||
Variadic<Index>:$source_dynamic | ||
)); | ||
let results = (outs | ||
AnyRankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)? | ||
attr-dict `:` functional-type(operands, results) | ||
}]; | ||
} | ||
|
||
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ | ||
AllShapesMatch<["input", "result"]> | ||
]> { | ||
let summary = "Reduce over a device mesh."; | ||
let description = [{ | ||
Reduces on device `root` within each device group. | ||
`root` specifies the coordinates of a device along `mesh_axes`. | ||
It uniquely identifies the root device within its device group. | ||
The accumulation element type is specified by the result type and | ||
it does not need to match the input element type. | ||
The input element is converted to the result element type before | ||
performing the reduction. | ||
|
||
Attributes: | ||
`reduction`: Indicates the reduction method. | ||
|
||
Example: | ||
``` | ||
%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0] | ||
reduction = <max> root = [2, 3] | ||
: (tensor<3x4xf32>) -> tensor<3x4xf64> | ||
``` | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyRankedTensor:$input, | ||
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction, | ||
DenseI64ArrayAttr:$root, | ||
Variadic<Index>:$root_dynamic | ||
)); | ||
let results = (outs | ||
AnyRankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
(`reduction` `=` $reduction^)? | ||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root) | ||
attr-dict `:` functional-type(operands, results) | ||
}]; | ||
} | ||
|
||
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [ | ||
SameOperandsAndResultRank]> { | ||
let summary = "Reduce-scatter over a device mesh."; | ||
|
@@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", | |
let hasCanonicalizer = 1; | ||
} | ||
|
||
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ | ||
AllRanksMatch<["input", "result"]>, | ||
AllElementTypesMatch<["input", "result"]> | ||
]> { | ||
let summary = "Scatter over a device mesh."; | ||
let description = [{ | ||
For each device group split the input tensor on the `root` device along | ||
axis `scatter_axis` and scatter the parts across the group devices. | ||
|
||
Example: | ||
``` | ||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) | ||
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0] | ||
scatter_axis = 0 | ||
root = [1] | ||
: (tensor<2x2xi8>) -> tensor<1x2xi8> | ||
``` | ||
|
||
Input: | ||
``` | ||
device | ||
(0, 1) | ||
↓ | ||
+-------+-------+ | scatter tensor | ||
device (0, 0) -> | | | | axis 0 | ||
| | | ↓ | ||
+-------+-------+ | ||
device (1, 0) -> | 1 2 | 5 6 | | ||
| 3 4 | 7 8 | | ||
+-------+-------+ | ||
↑ | ||
device | ||
(1, 1) | ||
``` | ||
|
||
Result: | ||
``` | ||
device | ||
(0, 1) | ||
↓ | ||
+-------+-------+ | ||
device (0, 0) -> | 1 2 | 5 6 | | ||
+-------+-------+ | ||
device (1, 0) -> | 3 4 | 7 8 | | ||
+-------+-------+ | ||
↑ | ||
device | ||
(1, 1) | ||
``` | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyNon0RankedTensor:$input, | ||
IndexAttr:$scatter_axis, | ||
DenseI64ArrayAttr:$root, | ||
Variadic<Index>:$root_dynamic | ||
)); | ||
let results = (outs | ||
AnyRankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
`scatter_axis` `=` $scatter_axis | ||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root) | ||
attr-dict `:` functional-type(operands, results) | ||
}]; | ||
} | ||
|
||
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [ | ||
AllShapesMatch<["input", "result"]>, | ||
AllElementTypesMatch<["input", "result"]> | ||
]> { | ||
let summary = "Send over a device mesh."; | ||
let description = [{ | ||
Send from one device to another within a device group. | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyNon0RankedTensor:$input, | ||
DenseI64ArrayAttr:$destination, | ||
Variadic<Index>:$destination_dynamic | ||
)); | ||
let results = (outs | ||
AnyRankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination) | ||
attr-dict `:` functional-type(operands, results) | ||
}]; | ||
} | ||
|
||
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ | ||
SameOperandsAndResultElementType, | ||
SameOperandsAndResultShape | ||
]> { | ||
let summary = "Shift over a device mesh."; | ||
let description = [{ | ||
Within each device group shift along mesh axis `shift_axis` by an offset | ||
`offset`. | ||
The result on devices that do not have a corresponding source is undefined. | ||
`shift_axis` must be one of `mesh_axes`. | ||
If the `rotate` attribute is present, | ||
instead of a shift a rotation is done. | ||
|
||
Example: | ||
``` | ||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) | ||
%1 = mesh.shift %0 on @mesh0 mesh_axes = [1] | ||
shift_axis = 1 offset = 2 rotate | ||
: tensor<2xi8> -> tensor<2xi8> | ||
``` | ||
|
||
Input: | ||
``` | ||
mesh axis 1 | ||
-----------> | ||
|
||
+----+----+----+----+ | ||
| 1 | 2 | 3 | 4 | | ||
+----+----+----+----+ | ||
| 5 | 6 | 7 | 8 | | ||
+----+----+----+----+ | ||
``` | ||
|
||
Result: | ||
``` | ||
+----+----+----+----+ | ||
| 3 | 4 | 1 | 2 | | ||
+----+----+----+----+ | ||
| 7 | 8 | 5 | 6 | | ||
+----+----+----+----+ | ||
``` | ||
}]; | ||
let arguments = !con(commonArgs, (ins | ||
AnyNon0RankedTensor:$input, | ||
IndexAttr:$shift_axis, | ||
I64Attr:$offset, | ||
UnitAttr:$rotate | ||
)); | ||
let results = (outs | ||
AnyRankedTensor:$result | ||
); | ||
let assemblyFormat = [{ | ||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? | ||
`shift_axis` `=` $shift_axis | ||
`offset` `=` $offset | ||
(`rotate` $rotate^)? | ||
attr-dict `:` type($input) `->` type($result) | ||
}]; | ||
} | ||
|
||
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use
AllTypesMatch
here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking about that, but can't there be another tensor type? Or maybe the result has something in the
encoding
field, then the program would be considered incorrect.