Skip to content

[mlir][spirv] Add GroupNonUniformVote instructions #141294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4464,6 +4464,9 @@ def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 2
def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
def SPIRV_OC_OpGroupNonUniformAllEqual : I32EnumAttrCase<"OpGroupNonUniformAllEqual", 336>;
def SPIRV_OC_OpGroupNonUniformBroadcast : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
def SPIRV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPIRV_OC_OpGroupNonUniformBallotBitCount : I32EnumAttrCase<"OpGroupNonUniformBallotBitCount", 342>;
Expand All @@ -4489,8 +4492,8 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
def SPIRV_OC_OpSUDot : I32EnumAttrCase<"OpSUDot", 4452>;
Expand Down Expand Up @@ -4581,11 +4584,13 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable,
SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd,
SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin,
SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, SPIRV_OC_OpGroupSMax,
SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed, SPIRV_OC_OpGroupNonUniformElect,
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
SPIRV_OC_OpGroupNonUniformBallotBitCount,
SPIRV_OC_OpGroupNonUniformBallotFindLSB,
Expand All @@ -4599,19 +4604,18 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
SPIRV_OC_OpControlBarrierWaitINTEL, SPIRV_OC_OpGroupIMulKHR,
SPIRV_OC_OpGroupFMulKHR
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR,
SPIRV_OC_OpCooperativeMatrixMulAddKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
]>;

// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
Expand Down
173 changes: 173 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1435,4 +1435,177 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [

// -----

// Subgroup vote op: yields true only when the predicate is true for every
// tangled invocation. The trait below restricts `execution_scope` to
// Subgroup; any other scope is rejected with
// "execution_scope must be Scope of value Subgroup".
def SPIRV_GroupNonUniformAllOp : SPIRV_Op<"GroupNonUniformAll", [
  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
]> {
  let summary = [{
    Evaluates a predicate for all tangled invocations within the Execution
    scope, resulting in true if predicate evaluates to true for all tangled
    invocations within the Execution scope, otherwise the result is false.
  }];

  let description = [{
    Result Type must be a Boolean type.

    Execution is the scope defining the scope restricted tangle affected by
    this command. It must be Subgroup.

    Predicate must be a Boolean type.

    An invocation will not execute a dynamic instance of this instruction
    (X') until all invocations in its scope restricted tangle have executed
    all dynamic instances that are program-ordered before X'.

    <!-- End of AutoGen section -->

    #### Example:

    ```mlir
    %predicate = ... : i1
    %0 = spirv.GroupNonUniformAll <Subgroup> %predicate : i1
    ```
  }];

  let availability = [
    MinVersion<SPIRV_V_1_3>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[]>,
    Capability<[SPIRV_C_GroupNonUniformVote]>
  ];

  let arguments = (ins
    SPIRV_ScopeAttr:$execution_scope,
    SPIRV_Bool:$predicate
  );

  let results = (outs
    SPIRV_Bool:$result
  );

  // No custom C++ verifier is requested for this op.
  let hasVerifier = 0;

  // Only the result type is spelled out in the textual form.
  let assemblyFormat = [{
    $execution_scope $predicate attr-dict `:` type($result)
  }];
}

// -----

// Subgroup vote op: yields true when the predicate is true for at least one
// tangled invocation. The trait below restricts `execution_scope` to
// Subgroup; any other scope is rejected with
// "execution_scope must be Scope of value Subgroup".
def SPIRV_GroupNonUniformAnyOp : SPIRV_Op<"GroupNonUniformAny", [
  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
]> {
  let summary = [{
    Evaluates a predicate for all tangled invocations within the Execution
    scope, resulting in true if predicate evaluates to true for any tangled
    invocations within the Execution scope, otherwise the result is false.
  }];

  let description = [{
    Result Type must be a Boolean type.

    Execution is the scope defining the scope restricted tangle affected by
    this command. It must be Subgroup.

    Predicate must be a Boolean type.

    An invocation will not execute a dynamic instance of this instruction
    (X') until all invocations in its scope restricted tangle have executed
    all dynamic instances that are program-ordered before X'.

    <!-- End of AutoGen section -->

    #### Example:

    ```mlir
    %predicate = ... : i1
    %0 = spirv.GroupNonUniformAny <Subgroup> %predicate : i1
    ```
  }];

  let availability = [
    MinVersion<SPIRV_V_1_3>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[]>,
    Capability<[SPIRV_C_GroupNonUniformVote]>
  ];

  let arguments = (ins
    SPIRV_ScopeAttr:$execution_scope,
    SPIRV_Bool:$predicate
  );

  let results = (outs
    SPIRV_Bool:$result
  );

  // No custom C++ verifier is requested for this op.
  let hasVerifier = 0;

  // Only the result type is spelled out in the textual form.
  let assemblyFormat = [{
    $execution_scope $predicate attr-dict `:` type($result)
  }];
}

// -----

// Subgroup vote op: yields true when `value` is equal across all tangled
// invocations. The trait below restricts `execution_scope` to Subgroup.
def SPIRV_GroupNonUniformAllEqualOp : SPIRV_Op<"GroupNonUniformAllEqual", [
  SPIRV_ExecutionScopeAttrIs<"execution_scope", "Subgroup">
]> {
  let summary = [{
    Evaluates a value for all tangled invocations within the Execution
    scope. The result is true if Value is equal for all tangled invocations
    within the Execution scope. Otherwise, the result is false.
  }];

  let description = [{
    Result Type must be a Boolean type.

    Execution is the scope defining the scope restricted tangle affected by
    this command. It must be Subgroup.

    Value must be a scalar or vector of floating-point type, integer type,
    or Boolean type. The compare operation is based on this type, and if it
    is a floating-point type, an ordered-and-equal compare is used.

    An invocation will not execute a dynamic instance of this instruction
    (X') until all invocations in its scope restricted tangle have executed
    all dynamic instances that are program-ordered before X'.

    <!-- End of AutoGen section -->

    #### Example:

    ```mlir
    %scalar_value = ... : f32
    %vector_value = ... : vector<4xf32>
    %0 = spirv.GroupNonUniformAllEqual <Subgroup> %scalar_value : f32, i1
    %1 = spirv.GroupNonUniformAllEqual <Subgroup> %vector_value : vector<4xf32>, i1
    ```
  }];

  let availability = [
    MinVersion<SPIRV_V_1_3>,
    MaxVersion<SPIRV_V_1_6>,
    Extension<[]>,
    Capability<[SPIRV_C_GroupNonUniformVote]>
  ];

  // `value` accepts a scalar or vector whose element type is float, integer,
  // or bool, matching the "Value must be ..." sentence in the description.
  let arguments = (ins
    SPIRV_ScopeAttr:$execution_scope,
    AnyTypeOf<[
      SPIRV_ScalarOrVectorOf<SPIRV_Float>,
      SPIRV_ScalarOrVectorOf<SPIRV_Integer>,
      SPIRV_ScalarOrVectorOf<SPIRV_Bool>
    ]>:$value
  );

  let results = (outs
    SPIRV_Bool:$result
  );

  // No custom C++ verifier is requested for this op.
  let hasVerifier = 0;

  // Operand and result types are both spelled out: `: <value-type>, <result-type>`.
  let assemblyFormat = [{
    $execution_scope $value attr-dict `:` type($value) `,` type($result)
  }];
}

// -----

#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
73 changes: 73 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,76 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
return %0: f32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformAll
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @group_non_uniform_all
func.func @group_non_uniform_all(%pred: i1) -> i1 {
  // Positive case: Subgroup scope with an i1 predicate; the FileCheck
  // pattern below pins the expected textual form.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAll <Subgroup> %{{.+}} : i1
  %all = spirv.GroupNonUniformAll <Subgroup> %pred : i1
  return %all : i1
}

// -----

func.func @group_non_uniform_all(%pred: i1) -> i1 {
  // Negative case: a scope other than Subgroup (<Device>) must produce the
  // diagnostic named on the directive immediately above the op.
  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
  %all = spirv.GroupNonUniformAll <Device> %pred : i1
  return %all : i1
}

// -----

//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformAny
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @group_non_uniform_any
func.func @group_non_uniform_any(%pred: i1) -> i1 {
  // Positive case: Subgroup scope with an i1 predicate; the FileCheck
  // pattern below pins the expected textual form.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAny <Subgroup> %{{.+}} : i1
  %any = spirv.GroupNonUniformAny <Subgroup> %pred : i1
  return %any : i1
}

// -----

func.func @group_non_uniform_any(%pred: i1) -> i1 {
  // Negative case: a scope other than Subgroup (<Device>) must produce the
  // diagnostic named on the directive immediately above the op.
  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
  %any = spirv.GroupNonUniformAny <Device> %pred : i1
  return %any : i1
}

// -----

//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformAllEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @group_non_uniform_all_equal
func.func @group_non_uniform_all_equal(%operand: f32) -> i1 {
  // Positive case: scalar f32 operand; both the operand type and the i1
  // result type appear in the textual form.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : f32, i1
  %equal = spirv.GroupNonUniformAllEqual <Subgroup> %operand : f32, i1
  return %equal : i1
}

// -----

// CHECK-LABEL: @group_non_uniform_all_equal
func.func @group_non_uniform_all_equal(%operand: vector<4xi32>) -> i1 {
  // Positive case: integer vector operand; the result is still a single i1.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : vector<4xi32>, i1
  %equal = spirv.GroupNonUniformAllEqual <Subgroup> %operand : vector<4xi32>, i1
  return %equal : i1
}


// -----

func.func @group_non_uniform_all_equal(%operand: f32) -> i1 {
  // Negative case: a scope other than Subgroup (<Device>) must produce the
  // diagnostic named on the directive immediately above the op.
  // expected-error @+1 {{execution_scope must be Scope of value Subgroup}}
  %equal = spirv.GroupNonUniformAllEqual <Device> %operand : f32, i1
  return %equal : i1
}
18 changes: 18 additions & 0 deletions mlir/test/Target/SPIRV/non-uniform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,22 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.GroupNonUniformShuffleXor <Subgroup> %val, %id : f32, i32
spirv.ReturnValue %0: f32
}

spirv.func @group_non_uniform_all(%vote: i1) -> i1 "None" {
  // The FileCheck pattern below pins the expected textual form of the op.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAll <Subgroup> %{{.+}} : i1
  %all = spirv.GroupNonUniformAll <Subgroup> %vote : i1
  spirv.ReturnValue %all : i1
}

spirv.func @group_non_uniform_any(%vote: i1) -> i1 "None" {
  // The FileCheck pattern below pins the expected textual form of the op.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAny <Subgroup> %{{.+}} : i1
  %any = spirv.GroupNonUniformAny <Subgroup> %vote : i1
  spirv.ReturnValue %any : i1
}

spirv.func @group_non_uniform_all_equal(%operand: vector<4xi32>) -> i1 "None" {
  // The FileCheck pattern below pins the expected textual form of the op,
  // including both the operand type and the i1 result type.
  // CHECK: %{{.+}} = spirv.GroupNonUniformAllEqual <Subgroup> %{{.+}} : vector<4xi32>, i1
  %equal = spirv.GroupNonUniformAllEqual <Subgroup> %operand : vector<4xi32>, i1
  spirv.ReturnValue %equal : i1
}
}