Skip to content

Commit 581c810

Browse files
committed
Add type constraints
1 parent 51d70a0 commit 581c810

File tree

3 files changed

+38
-12
lines changed

3 files changed

+38
-12
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -907,15 +907,14 @@ def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
907907
let genSpecializedAttr = 0;
908908
let cppNamespace = "::mlir::gpu";
909909
}
910+
911+
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
912+
910913
def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
911914
"all_reduce_op">;
912915

913916
def GPU_AllReduceOp : GPU_Op<"all_reduce",
914-
[SameOperandsAndResultType, IsolatedFromAbove]>,
915-
Arguments<(ins AnyType:$value,
916-
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
917-
UnitAttr:$uniform)>,
918-
Results<(outs AnyType)> {
917+
[SameOperandsAndResultType, IsolatedFromAbove]> {
919918
let summary = "Reduce values among workgroup.";
920919
let description = [{
921920
The `all_reduce` op reduces the value of every work item across a local
@@ -943,6 +942,14 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
943942
If `uniform` flag is set either none or all work items of a workgroup
944943
need to execute this op in convergence.
945944
}];
945+
946+
let arguments = (ins
947+
AnyIntegerOrFloat:$value,
948+
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
949+
UnitAttr:$uniform
950+
);
951+
let results = (outs AnyIntegerOrFloat:$result);
952+
946953
let regions = (region AnyRegion:$body);
947954
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
948955
(`uniform` $uniform^)? $body attr-dict
@@ -952,12 +959,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
952959
let hasRegionVerifier = 1;
953960
}
954961

955-
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
956-
[SameOperandsAndResultType]>,
957-
Arguments<(ins AnyType:$value,
958-
GPU_AllReduceOperationAttr:$op,
959-
UnitAttr:$uniform)>,
960-
Results<(outs AnyType)> {
962+
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
961963
let summary = "Reduce values among subgroup.";
962964
let description = [{
963965
The `subgroup_reduce` op reduces the value of every work item across a
@@ -977,6 +979,14 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
977979
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
978980
`maximumf`
979981
}];
982+
983+
let arguments = (ins
984+
AnyIntegerOrFloat:$value,
985+
GPU_AllReduceOperationAttr:$op,
986+
UnitAttr:$uniform
987+
);
988+
let results = (outs AnyIntegerOrFloat:$result);
989+
980990
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
981991
(`uniform` $uniform^)? attr-dict
982992
`:` functional-type(operands, results) }];

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3434
!::llvm::cast<VectorType>($_self).isScalable()}]>;
3535

3636
// Whether a type is a scalable VectorType.
37-
def IsVectorTypeWithAnyDimScalablePred
37+
def IsVectorTypeWithAnyDimScalablePred
3838
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3939
::llvm::cast<VectorType>($_self).isScalable()}]>;
4040

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ module attributes {gpu.container_module} {
210210

211211
// -----
212212

213+
func.func @reduce_bad_type(%arg0 : vector<4xf32>) {
214+
// expected-error@+1 {{'gpu.all_reduce' op operand #0 must be Integer or Float}}
215+
%res = gpu.all_reduce add %arg0 {} : (vector<4xf32>) -> vector<4xf32>
216+
return
217+
}
218+
219+
// -----
220+
213221
func.func @reduce_no_op_no_body(%arg0 : f32) {
214222
// expected-error@+1 {{expected either an op attribute or a non-empty body}}
215223
%res = "gpu.all_reduce"(%arg0) ({}) : (f32) -> (f32)
@@ -325,6 +333,14 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
325333

326334
// -----
327335

336+
func.func @subgroup_reduce_bad_type(%arg0 : vector<2xf32>) {
337+
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float}}
338+
%res = gpu.subgroup_reduce add %arg0 : (vector<2xf32>) -> vector<2xf32>
339+
return
340+
}
341+
342+
// -----
343+
328344
func.func @subgroup_reduce_invalid_op_type_and(%arg0 : f32) {
329345
// expected-error@+1 {{`and` reduction operation is not compatible with type 'f32'}}
330346
%res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)

0 commit comments

Comments
 (0)