@@ -907,15 +907,14 @@ def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
907
907
let genSpecializedAttr = 0;
908
908
let cppNamespace = "::mlir::gpu";
909
909
}
910
+
911
+ def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
912
+
910
913
def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
911
914
"all_reduce_op">;
912
915
913
916
def GPU_AllReduceOp : GPU_Op<"all_reduce",
914
- [SameOperandsAndResultType, IsolatedFromAbove]>,
915
- Arguments<(ins AnyType:$value,
916
- OptionalAttr<GPU_AllReduceOperationAttr>:$op,
917
- UnitAttr:$uniform)>,
918
- Results<(outs AnyType)> {
917
+ [SameOperandsAndResultType, IsolatedFromAbove]> {
919
918
let summary = "Reduce values among workgroup.";
920
919
let description = [{
921
920
The `all_reduce` op reduces the value of every work item across a local
@@ -943,6 +942,14 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
943
942
If `uniform` flag is set either none or all work items of a workgroup
944
943
need to execute this op in convergence.
945
944
}];
945
+
946
+ let arguments = (ins
947
+ AnyIntegerOrFloat:$value,
948
+ OptionalAttr<GPU_AllReduceOperationAttr>:$op,
949
+ UnitAttr:$uniform
950
+ );
951
+ let results = (outs AnyIntegerOrFloat:$result);
952
+
946
953
let regions = (region AnyRegion:$body);
947
954
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
948
955
(`uniform` $uniform^)? $body attr-dict
@@ -952,12 +959,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
952
959
let hasRegionVerifier = 1;
953
960
}
954
961
955
- def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
956
- [SameOperandsAndResultType]>,
957
- Arguments<(ins AnyType:$value,
958
- GPU_AllReduceOperationAttr:$op,
959
- UnitAttr:$uniform)>,
960
- Results<(outs AnyType)> {
962
+ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
961
963
let summary = "Reduce values among subgroup.";
962
964
let description = [{
963
965
The `subgroup_reduce` op reduces the value of every work item across a
@@ -977,6 +979,14 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
977
979
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
978
980
`maximumf`
979
981
}];
982
+
983
+ let arguments = (ins
984
+ AnyIntegerOrFloat:$value,
985
+ GPU_AllReduceOperationAttr:$op,
986
+ UnitAttr:$uniform
987
+ );
988
+ let results = (outs AnyIntegerOrFloat:$result);
989
+
980
990
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
981
991
(`uniform` $uniform^)? attr-dict
982
992
`:` functional-type(operands, results) }];
0 commit comments