Skip to content

Commit 4df5310

Browse files
authored
[mlir][spirv] Use assemblyFormat to define groupNonUniform op assembly (#115662)
Declarative assemblyFormat ODS is more concise and requires less boilerplate than filling out CPP interfaces. Changes: * updates the Ops defined in `SPIRVNonUniformOps.td and SPIRVGroupOps.td` to use assemblyFormat. * Removes print/parse from `GroupOps.cpp` which is now generated by assemblyFormat * Updates tests to updated format (largely using <operand> in place of "operand" and complementing type information) Issue: #73359
1 parent 0baa6a7 commit 4df5310

File tree

13 files changed

+157
-496
lines changed

13 files changed

+157
-496
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,12 @@ def SPIRV_INTELSubgroupBlockReadOp : SPIRV_IntelVendorOp<"SubgroupBlockRead", []
661661
let results = (outs
662662
SPIRV_Type:$value
663663
);
664+
665+
let hasCustomAssemblyFormat = 0;
666+
667+
let assemblyFormat = [{
668+
$ptr attr-dict `:` type($ptr) `->` type($value)
669+
}];
664670
}
665671

666672
// -----

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td

Lines changed: 45 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@ class SPIRV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
2626

2727
let results = (outs
2828
SPIRV_ScalarOrVectorOf<type>:$result
29-
);
29+
);
30+
31+
let hasCustomAssemblyFormat = 0;
32+
33+
let assemblyFormat = [{
34+
$execution_scope $group_operation $value (`cluster_size``(` $cluster_size^ `)`)? attr-dict `:` type($value) (`,` type($cluster_size)^)? `->` type(results)
35+
}];
3036
}
3137

3238
// -----
@@ -318,24 +324,14 @@ def SPIRV_GroupNonUniformFAddOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
318324

319325
<!-- End of AutoGen section -->
320326

321-
```
322-
scope ::= `"Workgroup"` | `"Subgroup"`
323-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
324-
float-scalar-vector-type ::= float-type |
325-
`vector<` integer-literal `x` float-type `>`
326-
non-uniform-fadd-op ::= ssa-id `=` `spirv.GroupNonUniformFAdd` scope operation
327-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
328-
`:` float-scalar-vector-type
329-
```
330-
331327
#### Example:
332328

333329
```mlir
334330
%four = spirv.Constant 4 : i32
335331
%scalar = ... : f32
336332
%vector = ... : vector<4xf32>
337-
%0 = spirv.GroupNonUniformFAdd "Workgroup" "Reduce" %scalar : f32
338-
%1 = spirv.GroupNonUniformFAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
333+
%0 = spirv.GroupNonUniformFAdd <Workgroup> <Reduce> %scalar : f32 -> f32
334+
%1 = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xf32>, i32 -> vector<4xf32>
339335
```
340336
}];
341337

@@ -378,24 +374,14 @@ def SPIRV_GroupNonUniformFMaxOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
378374

379375
<!-- End of AutoGen section -->
380376

381-
```
382-
scope ::= `"Workgroup"` | `"Subgroup"`
383-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
384-
float-scalar-vector-type ::= float-type |
385-
`vector<` integer-literal `x` float-type `>`
386-
non-uniform-fmax-op ::= ssa-id `=` `spirv.GroupNonUniformFMax` scope operation
387-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
388-
`:` float-scalar-vector-type
389-
```
390-
391377
#### Example:
392378

393379
```mlir
394380
%four = spirv.Constant 4 : i32
395381
%scalar = ... : f32
396382
%vector = ... : vector<4xf32>
397-
%0 = spirv.GroupNonUniformFMax "Workgroup" "Reduce" %scalar : f32
398-
%1 = spirv.GroupNonUniformFMax "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
383+
%0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %scalar : f32 -> f32
384+
%1 = spirv.GroupNonUniformFMax <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xf32>, i32 -> vector<4xf32>
399385
```
400386
}];
401387

@@ -438,24 +424,14 @@ def SPIRV_GroupNonUniformFMinOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
438424

439425
<!-- End of AutoGen section -->
440426

441-
```
442-
scope ::= `"Workgroup"` | `"Subgroup"`
443-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
444-
float-scalar-vector-type ::= float-type |
445-
`vector<` integer-literal `x` float-type `>`
446-
non-uniform-fmin-op ::= ssa-id `=` `spirv.GroupNonUniformFMin` scope operation
447-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
448-
`:` float-scalar-vector-type
449-
```
450-
451427
#### Example:
452428

453429
```mlir
454430
%four = spirv.Constant 4 : i32
455431
%scalar = ... : f32
456432
%vector = ... : vector<4xf32>
457-
%0 = spirv.GroupNonUniformFMin "Workgroup" "Reduce" %scalar : f32
458-
%1 = spirv.GroupNonUniformFMin "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
433+
%0 = spirv.GroupNonUniformFMin <Workgroup> <Reduce> %scalar : f32 -> i32
434+
%1 = spirv.GroupNonUniformFMin <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xf32>, i32 -> vector<4xf32>
459435
```
460436
}];
461437

@@ -495,24 +471,14 @@ def SPIRV_GroupNonUniformFMulOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
495471

496472
<!-- End of AutoGen section -->
497473

498-
```
499-
scope ::= `"Workgroup"` | `"Subgroup"`
500-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
501-
float-scalar-vector-type ::= float-type |
502-
`vector<` integer-literal `x` float-type `>`
503-
non-uniform-fmul-op ::= ssa-id `=` `spirv.GroupNonUniformFMul` scope operation
504-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
505-
`:` float-scalar-vector-type
506-
```
507-
508474
#### Example:
509475

510476
```mlir
511477
%four = spirv.Constant 4 : i32
512478
%scalar = ... : f32
513479
%vector = ... : vector<4xf32>
514-
%0 = spirv.GroupNonUniformFMul "Workgroup" "Reduce" %scalar : f32
515-
%1 = spirv.GroupNonUniformFMul "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
480+
%0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %scalar : f32 -> f32
481+
%1 = spirv.GroupNonUniformFMul <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xf32>, i32 -> vector<4xf32>
516482
```
517483
}];
518484

@@ -550,24 +516,14 @@ def SPIRV_GroupNonUniformIAddOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
550516

551517
<!-- End of AutoGen section -->
552518

553-
```
554-
scope ::= `"Workgroup"` | `"Subgroup"`
555-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
556-
integer-scalar-vector-type ::= integer-type |
557-
`vector<` integer-literal `x` integer-type `>`
558-
non-uniform-iadd-op ::= ssa-id `=` `spirv.GroupNonUniformIAdd` scope operation
559-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
560-
`:` integer-scalar-vector-type
561-
```
562-
563519
#### Example:
564520

565521
```mlir
566522
%four = spirv.Constant 4 : i32
567523
%scalar = ... : i32
568524
%vector = ... : vector<4xi32>
569-
%0 = spirv.GroupNonUniformIAdd "Workgroup" "Reduce" %scalar : i32
570-
%1 = spirv.GroupNonUniformIAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
525+
%0 = spirv.GroupNonUniformIAdd <Workgroup> <Reduce> %scalar : i32 -> i32
526+
%1 = spirv.GroupNonUniformIAdd <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
571527
```
572528
}];
573529

@@ -605,24 +561,14 @@ def SPIRV_GroupNonUniformIMulOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
605561

606562
<!-- End of AutoGen section -->
607563

608-
```
609-
scope ::= `"Workgroup"` | `"Subgroup"`
610-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
611-
integer-scalar-vector-type ::= integer-type |
612-
`vector<` integer-literal `x` integer-type `>`
613-
non-uniform-imul-op ::= ssa-id `=` `spirv.GroupNonUniformIMul` scope operation
614-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
615-
`:` integer-scalar-vector-type
616-
```
617-
618564
#### Example:
619565

620566
```mlir
621567
%four = spirv.Constant 4 : i32
622568
%scalar = ... : i32
623569
%vector = ... : vector<4xi32>
624-
%0 = spirv.GroupNonUniformIMul "Workgroup" "Reduce" %scalar : i32
625-
%1 = spirv.GroupNonUniformIMul "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
570+
%0 = spirv.GroupNonUniformIMul <Workgroup> <Reduce> %scalar : i32 -> i32
571+
%1 = spirv.GroupNonUniformIMul <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
626572
```
627573
}];
628574

@@ -662,24 +608,14 @@ def SPIRV_GroupNonUniformSMaxOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
662608

663609
<!-- End of AutoGen section -->
664610

665-
```
666-
scope ::= `"Workgroup"` | `"Subgroup"`
667-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
668-
integer-scalar-vector-type ::= integer-type |
669-
`vector<` integer-literal `x` integer-type `>`
670-
non-uniform-smax-op ::= ssa-id `=` `spirv.GroupNonUniformSMax` scope operation
671-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
672-
`:` integer-scalar-vector-type
673-
```
674-
675611
#### Example:
676612

677613
```mlir
678614
%four = spirv.Constant 4 : i32
679615
%scalar = ... : i32
680616
%vector = ... : vector<4xi32>
681-
%0 = spirv.GroupNonUniformSMax "Workgroup" "Reduce" %scalar : i32
682-
%1 = spirv.GroupNonUniformSMax "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
617+
%0 = spirv.GroupNonUniformSMax <Workgroup> <Reduce> %scalar : i32
618+
%1 = spirv.GroupNonUniformSMax <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
683619
```
684620
}];
685621

@@ -719,24 +655,14 @@ def SPIRV_GroupNonUniformSMinOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
719655

720656
<!-- End of AutoGen section -->
721657

722-
```
723-
scope ::= `"Workgroup"` | `"Subgroup"`
724-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
725-
integer-scalar-vector-type ::= integer-type |
726-
`vector<` integer-literal `x` integer-type `>`
727-
non-uniform-smin-op ::= ssa-id `=` `spirv.GroupNonUniformSMin` scope operation
728-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
729-
`:` integer-scalar-vector-type
730-
```
731-
732658
#### Example:
733659

734660
```mlir
735661
%four = spirv.Constant 4 : i32
736662
%scalar = ... : i32
737663
%vector = ... : vector<4xi32>
738-
%0 = spirv.GroupNonUniformSMin "Workgroup" "Reduce" %scalar : i32
739-
%1 = spirv.GroupNonUniformSMin "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
664+
%0 = spirv.GroupNonUniformSMin <Workgroup> <Reduce> %scalar : i32 -> i32
665+
%1 = spirv.GroupNonUniformSMin <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
740666
```
741667
}];
742668

@@ -992,24 +918,14 @@ def SPIRV_GroupNonUniformUMaxOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
992918

993919
<!-- End of AutoGen section -->
994920

995-
```
996-
scope ::= `"Workgroup"` | `"Subgroup"`
997-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
998-
integer-scalar-vector-type ::= integer-type |
999-
`vector<` integer-literal `x` integer-type `>`
1000-
non-uniform-umax-op ::= ssa-id `=` `spirv.GroupNonUniformUMax` scope operation
1001-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
1002-
`:` integer-scalar-vector-type
1003-
```
1004-
1005921
#### Example:
1006922

1007923
```mlir
1008924
%four = spirv.Constant 4 : i32
1009925
%scalar = ... : i32
1010926
%vector = ... : vector<4xi32>
1011-
%0 = spirv.GroupNonUniformUMax "Workgroup" "Reduce" %scalar : i32
1012-
%1 = spirv.GroupNonUniformUMax "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
927+
%0 = spirv.GroupNonUniformUMax <Workgroup> <Reduce> %scalar : i32 -> i32
928+
%1 = spirv.GroupNonUniformUMax <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
1013929
```
1014930
}];
1015931

@@ -1050,24 +966,14 @@ def SPIRV_GroupNonUniformUMinOp : SPIRV_GroupNonUniformArithmeticOp<"GroupNonUni
1050966

1051967
<!-- End of AutoGen section -->
1052968

1053-
```
1054-
scope ::= `"Workgroup"` | `"Subgroup"`
1055-
operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
1056-
integer-scalar-vector-type ::= integer-type |
1057-
`vector<` integer-literal `x` integer-type `>`
1058-
non-uniform-umin-op ::= ssa-id `=` `spirv.GroupNonUniformUMin` scope operation
1059-
ssa-use ( `cluster_size` `(` ssa_use `)` )?
1060-
`:` integer-scalar-vector-type
1061-
```
1062-
1063969
#### Example:
1064970

1065971
```mlir
1066972
%four = spirv.Constant 4 : i32
1067973
%scalar = ... : i32
1068974
%vector = ... : vector<4xi32>
1069-
%0 = spirv.GroupNonUniformUMin "Workgroup" "Reduce" %scalar : i32
1070-
%1 = spirv.GroupNonUniformUMin "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
975+
%0 = spirv.GroupNonUniformUMin <Workgroup> <Reduce> %scalar : i32 -> i32
976+
%1 = spirv.GroupNonUniformUMin <Subgroup> <ClusteredReduce> %vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
1071977
```
1072978
}];
1073979

@@ -1113,9 +1019,9 @@ def SPIRV_GroupNonUniformBitwiseAndOp :
11131019
%four = spirv.Constant 4 : i32
11141020
%scalar = ... : i32
11151021
%vector = ... : vector<4xi32>
1116-
%0 = spirv.GroupNonUniformBitwiseAnd "Workgroup" "Reduce" %scalar : i32
1117-
%1 = spirv.GroupNonUniformBitwiseAnd "Subgroup" "ClusteredReduce"
1118-
%vector cluster_size(%four) : vector<4xi32>
1022+
%0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %scalar : i32 -> i32
1023+
%1 = spirv.GroupNonUniformBitwiseAnd <Subgroup> <ClusteredReduce>
1024+
%vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
11191025
```
11201026
}];
11211027

@@ -1163,9 +1069,9 @@ def SPIRV_GroupNonUniformBitwiseOrOp :
11631069
%four = spirv.Constant 4 : i32
11641070
%scalar = ... : i32
11651071
%vector = ... : vector<4xi32>
1166-
%0 = spirv.GroupNonUniformBitwiseOr "Workgroup" "Reduce" %scalar : i32
1167-
%1 = spirv.GroupNonUniformBitwiseOr "Subgroup" "ClusteredReduce"
1168-
%vector cluster_size(%four) : vector<4xi32>
1072+
%0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %scalar : i32 -> i32
1073+
%1 = spirv.GroupNonUniformBitwiseOr <Subgroup> <ClusteredReduce>
1074+
%vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
11691075
```
11701076
}];
11711077

@@ -1213,9 +1119,9 @@ def SPIRV_GroupNonUniformBitwiseXorOp :
12131119
%four = spirv.Constant 4 : i32
12141120
%scalar = ... : i32
12151121
%vector = ... : vector<4xi32>
1216-
%0 = spirv.GroupNonUniformBitwiseXor "Workgroup" "Reduce" %scalar : i32
1217-
%1 = spirv.GroupNonUniformBitwiseXor "Subgroup" "ClusteredReduce"
1218-
%vector cluster_size(%four) : vector<4xi32>
1122+
%0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %scalar : i32 -> i32
1123+
%1 = spirv.GroupNonUniformBitwiseXor <Subgroup> <ClusteredReduce>
1124+
%vector cluster_size(%four) : vector<4xi32>, i32 -> vector<4xi32>
12191125
```
12201126
}];
12211127

@@ -1263,9 +1169,9 @@ def SPIRV_GroupNonUniformLogicalAndOp :
12631169
%four = spirv.Constant 4 : i32
12641170
%scalar = ... : i1
12651171
%vector = ... : vector<4xi1>
1266-
%0 = spirv.GroupNonUniformLogicalAnd "Workgroup" "Reduce" %scalar : i1
1267-
%1 = spirv.GroupNonUniformLogicalAnd "Subgroup" "ClusteredReduce"
1268-
%vector cluster_size(%four) : vector<4xi1>
1172+
%0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %scalar : i1 -> i1
1173+
%1 = spirv.GroupNonUniformLogicalAnd <Subgroup> <ClusteredReduce>
1174+
%vector cluster_size(%four) : vector<4xi1>, i32 -> vector<4xi1>
12691175
```
12701176
}];
12711177

@@ -1313,9 +1219,9 @@ def SPIRV_GroupNonUniformLogicalOrOp :
13131219
%four = spirv.Constant 4 : i32
13141220
%scalar = ... : i1
13151221
%vector = ... : vector<4xi1>
1316-
%0 = spirv.GroupNonUniformLogicalOr "Workgroup" "Reduce" %scalar : i1
1317-
%1 = spirv.GroupNonUniformLogicalOr "Subgroup" "ClusteredReduce"
1318-
%vector cluster_size(%four) : vector<4xi1>
1222+
%0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %scalar : i1 -> i1
1223+
%1 = spirv.GroupNonUniformLogicalOr <Subgroup> <ClusteredReduce>
1224+
%vector cluster_size(%four) : vector<4xi1>, i32 -> vector<4xi1>
13191225
```
13201226
}];
13211227

@@ -1363,9 +1269,9 @@ def SPIRV_GroupNonUniformLogicalXorOp :
13631269
%four = spirv.Constant 4 : i32
13641270
%scalar = ... : i1
13651271
%vector = ... : vector<4xi1>
1366-
%0 = spirv.GroupNonUniformLogicalXor "Workgroup" "Reduce" %scalar : i1
1367-
%1 = spirv.GroupNonUniformLogicalXor "Subgroup" "ClusteredReduce"
1368-
%vector cluster_size(%four) : vector<4xi>
1272+
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %scalar : i1 -> i1
1273+
%1 = spirv.GroupNonUniformLogicalXor <Subgroup> <ClusteredReduce>
1274+
%vector cluster_size(%four) : vector<4xi1>, i32 -> vector<4xi1>
13691275
```
13701276
}];
13711277

0 commit comments

Comments
 (0)