Skip to content

Commit a682860

Browse files
authored
[mlir][spirv] Add GroupNonUniformVote instructions (#141294)
Adds three SPIRV instructions under the `GroupNonUniformVote` capability: - OpGroupNonUniformAll - OpGroupNonUniformAny - OpGroupNonUniformAllEqual
1 parent c41a4a8 commit a682860

File tree

4 files changed

+287
-19
lines changed

4 files changed

+287
-19
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4464,6 +4464,9 @@ def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 2
44644464
def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
44654465
def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
44664466
def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
4467+
def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
4468+
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
4469+
def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
44674470
def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
44684471
def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
44694472
def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
@@ -4489,8 +4492,8 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
44894492
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
44904493
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
44914494
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
4492-
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
44934495
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
4496+
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
44944497
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
44954498
def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
44964499
def SPIRV_OC_OpSUDot : I32EnumAttrCase<"OpSUDot", 4452>;
@@ -4581,11 +4584,13 @@ def SPIRV_OpcodeAttr :
45814584
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
45824585
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
45834586
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
4584-
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable,
4585-
SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd,
4586-
SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin,
4587-
SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, SPIRV_OC_OpGroupSMax,
4588-
SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, SPIRV_OC_OpGroupNonUniformElect,
4587+
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
4588+
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
4589+
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
4590+
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
4591+
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
4592+
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
4593+
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
45894594
SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
45904595
SPIRV_OC_OpGroupNonUniformBallotBitCount,
45914596
SPIRV_OC_OpGroupNonUniformBallotFindLSB,
@@ -4599,19 +4604,18 @@ def SPIRV_OpcodeAttr :
45994604
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
46004605
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
46014606
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
4602-
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
4603-
SPIRV_OC_OpSubgroupBallotKHR,
4604-
SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
4605-
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
4606-
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
4607-
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
4608-
SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
4609-
SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
4610-
SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
4611-
SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
4612-
SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
4613-
SPIRV_OC_OpControlBarrierWaitINTEL, SPIRV_OC_OpGroupIMulKHR,
4614-
SPIRV_OC_OpGroupFMulKHR
4607+
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
4608+
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
4609+
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
4610+
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
4611+
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
4612+
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
4613+
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
4614+
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
4615+
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
4616+
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
4617+
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
4618+
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
46154619
]>;
46164620

46174621
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,4 +1435,177 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
14351435

14361436
// -----
14371437

1438+
def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
1439+
SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
1440+
]> {
1441+
let summary = [{
1442+
Evaluates a predicate for all tangled invocations within the Execution
1443+
scope, resulting in true if predicate evaluates to true for all tangled
1444+
invocations within the Execution scope, otherwise the result is false.
1445+
}];
1446+
1447+
let description = [{
1448+
Result Type must be a Boolean type.
1449+
1450+
Execution is the scope defining the scope restricted tangle affected by
1451+
this command. It must be Subgroup.
1452+
1453+
Predicate must be a Boolean type.
1454+
1455+
An invocation will not execute a dynamic instance of this instruction
1456+
(X') until all invocations in its scope restricted tangle have executed
1457+
all dynamic instances that are program-ordered before X'.
1458+
1459+
<!-- End of AutoGen section -->
1460+
1461+
#### Example:
1462+
1463+
```mlir
1464+
%predicate = ... : i1
1465+
%0 = spirv.GroupNonUniformAll "Subgroup" %predicate : i1
1466+
```
1467+
}];
1468+
1469+
let availability = [
1470+
MinVersion<SPIRV_V_1_3>,
1471+
MaxVersion<SPIRV_V_1_6>,
1472+
Extension<[]>,
1473+
Capability<[SPIRV_C_GroupNonUniformVote]>
1474+
];
1475+
1476+
let arguments = (ins
1477+
SPIRV_ScopeAttr:$execution_scope,
1478+
SPIRV_Bool:$predicate
1479+
);
1480+
1481+
let results = (outs
1482+
SPIRV_Bool:$result
1483+
);
1484+
1485+
let hasVerifier = 0;
1486+
1487+
let assemblyFormat = [{
1488+
$execution_scope $predicate attr-dict `:` type($result)
1489+
}];
1490+
}
1491+
1492+
// -----
1493+
1494+
def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
1495+
SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
1496+
]> {
1497+
let summary = [{
1498+
Evaluates a predicate for all tangled invocations within the Execution
1499+
scope, resulting in true if predicate evaluates to true for any tangled
1500+
invocations within the Execution scope, otherwise the result is false.
1501+
}];
1502+
1503+
let description = [{
1504+
Result Type must be a Boolean type.
1505+
1506+
Execution is the scope defining the scope restricted tangle affected by
1507+
this command. It must be Subgroup.
1508+
1509+
Predicate must be a Boolean type.
1510+
1511+
An invocation will not execute a dynamic instance of this instruction
1512+
(X') until all invocations in its scope restricted tangle have executed
1513+
all dynamic instances that are program-ordered before X'.
1514+
1515+
<!-- End of AutoGen section -->
1516+
1517+
#### Example:
1518+
1519+
```mlir
1520+
%predicate = ... : i1
1521+
%0 = spirv.GroupNonUniformAny "Subgroup" %predicate : i1
1522+
```
1523+
}];
1524+
1525+
let availability = [
1526+
MinVersion<SPIRV_V_1_3>,
1527+
MaxVersion<SPIRV_V_1_6>,
1528+
Extension<[]>,
1529+
Capability<[SPIRV_C_GroupNonUniformVote]>
1530+
];
1531+
1532+
let arguments = (ins
1533+
SPIRV_ScopeAttr:$execution_scope,
1534+
SPIRV_Bool:$predicate
1535+
);
1536+
1537+
let results = (outs
1538+
SPIRV_Bool:$result
1539+
);
1540+
1541+
let hasVerifier = 0;
1542+
1543+
let assemblyFormat = [{
1544+
$execution_scope $predicate attr-dict `:` type($result)
1545+
}];
1546+
}
1547+
1548+
// -----
1549+
1550+
def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
1551+
SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
1552+
]> {
1553+
let summary = [{
1554+
Evaluates a value for all tangled invocations within the Execution
1555+
scope. The result is true if Value is equal for all tangled invocations
1556+
within the Execution scope. Otherwise, the result is false.
1557+
}];
1558+
1559+
let description = [{
1560+
Result Type must be a Boolean type.
1561+
1562+
Execution is the scope defining the scope restricted tangle affected by
1563+
this command. It must be Subgroup.
1564+
1565+
Value must be a scalar or vector of floating-point type, integer type,
1566+
or Boolean type. The compare operation is based on this type, and if it
1567+
is a floating-point type, an ordered-and-equal compare is used.
1568+
1569+
An invocation will not execute a dynamic instance of this instruction
1570+
(X') until all invocations in its scope restricted tangle have executed
1571+
all dynamic instances that are program-ordered before X'.
1572+
1573+
<!-- End of AutoGen section -->
1574+
1575+
#### Example:
1576+
1577+
```mlir
1578+
%scalar_value = ... : f32
1579+
%vector_value = ... : vector<4xf32>
1580+
%0 = spirv.GroupNonUniformAllEqual <Subgroup> %scalar_value : f32, i1
1581+
%1 = spirv.GroupNonUniformAllEqual <Subgroup> %vector_value : vector<4xf32>, i1
1582+
```
1583+
}];
1584+
1585+
let availability = [
1586+
MinVersion<SPIRV_V_1_3>,
1587+
MaxVersion<SPIRV_V_1_6>,
1588+
Extension<[]>,
1589+
Capability<[SPIRV_C_GroupNonUniformVote]>
1590+
];
1591+
1592+
let arguments = (ins
1593+
SPIRV_ScopeAttr:$execution_scope,
1594+
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value
1595+
);
1596+
1597+
let results = (outs
1598+
SPIRV_Bool:$result
1599+
);
1600+
1601+
1602+
let hasVerifier = 0;
1603+
1604+
let assemblyFormat = [{
1605+
$execution_scope $value attr-dict `:` type($value) `,` type($result)
1606+
}];
1607+
}
1608+
1609+
// -----
1610+
14381611
#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS

mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,76 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
671671
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
672672
return %0: f32
673673
}
674+
675+
// -----
676+
677+
//===----------------------------------------------------------------------===//
678+
// spirv.GroupNonUniformAll
679+
//===----------------------------------------------------------------------===//
680+
681+
// CHECK-LABEL: @group_non_uniform_all
682+
func.func @group_non_uniform_all(%predicate: i1) -> i1 {
683+
// CHECK: %{{.+}} = spirv.GroupNonUniformAll <Subgroup> %{{.+}} : i1
684+
%0 = spirv.GroupNonUniformAll <Subgroup> %predicate : i1
685+
return %0: i1
686+
}
687+
688+
// -----
689+
690+
func.func @group_non_uniform_all(%predicate: i1) -> i1 {
691+
// expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
692+
%0 = spirv.GroupNonUniformAll <Device> %predicate : i1
693+
return %0: i1
694+
}
695+
696+
// -----
697+
698+
//===----------------------------------------------------------------------===//
699+
// spirv.GroupNonUniformAny
700+
//===----------------------------------------------------------------------===//
701+
702+
// CHECK-LABEL: @group_non_uniform_any
703+
func.func @group_non_uniform_any(%predicate: i1) -> i1 {
704+
// CHECK: %{{.+}} = spirv.GroupNonUniformAny <Subgroup> %{{.+}} : i1
705+
%0 = spirv.GroupNonUniformAny <Subgroup> %predicate : i1
706+
return %0: i1
707+
}
708+
709+
// -----
710+
711+
func.func @group_non_uniform_any(%predicate: i1) -> i1 {
712+
// expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
713+
%0 = spirv.GroupNonUniformAny <Device> %predicate : i1
714+
return %0: i1
715+
}
716+
717+
// -----
718+
719+
//===----------------------------------------------------------------------===//
720+
// spirv.GroupNonUniformAllEqual
721+
//===----------------------------------------------------------------------===//
722+
723+
// CHECK-LABEL: @group_non_uniform_all_equal
724+
func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
725+
// CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : f32, i1
726+
%0 = spirv.GroupNonUniformAllEqual <Subgroup> %value : f32, i1
727+
return %0: i1
728+
}
729+
730+
// -----
731+
732+
// CHECK-LABEL: @group_non_uniform_all_equal
733+
func.func @group_non_uniform_all_equal(%value: vector<4xi32>) -> i1 {
734+
// CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : vector<4xi32>, i1
735+
%0 = spirv.GroupNonUniformAllEqual <Subgroup> %value : vector<4xi32>, i1
736+
return %0: i1
737+
}
738+
739+
740+
// -----
741+
742+
func.func @group_non_uniform_all_equal(%value: f32) -> i1 {
743+
// expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
744+
%0 = spirv.GroupNonUniformAllEqual <Device> %value : f32, i1
745+
return %0: i1
746+
}

mlir/test/Target/SPIRV/non-uniform-ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,22 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
124124
%0 = spirv.GroupNonUniformShuffleXor <Subgroup> %val, %id : f32, i32
125125
spirv.ReturnValue %0: f32
126126
}
127+
128+
spirv.func @group_non_uniform_all(%pred: i1) -> i1 "None" {
129+
// CHECK: %{{.+}} = spirv.GroupNonUniformAll <Subgroup> %{{.+}} : i1
130+
%0 = spirv.GroupNonUniformAll <Subgroup> %pred : i1
131+
spirv.ReturnValue %0: i1
132+
}
133+
134+
spirv.func @group_non_uniform_any(%pred: i1) -> i1 "None" {
135+
// CHECK: %{{.+}} = spirv.GroupNonUniformAny <Subgroup> %{{.+}} : i1
136+
%0 = spirv.GroupNonUniformAny <Subgroup> %pred : i1
137+
spirv.ReturnValue %0: i1
138+
}
139+
140+
spirv.func @group_non_uniform_all_equal(%val: vector<4xi32>) -> i1 "None" {
141+
// CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : vector<4xi32>, i1
142+
%0 = spirv.GroupNonUniformAllEqual <Subgroup> %val : vector<4xi32>, i1
143+
spirv.ReturnValue %0: i1
144+
}
127145
}

0 commit comments

Comments
 (0)