-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add instruction OpGroupNonUniformRotateKHR #133428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add an instruction under the extension SPV_KHR_subgroup_rotate. The specification for the extension is here: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Hsiangkai Wang (Hsiangkai) ChangesAdd an instruction under the extension SPV_KHR_subgroup_rotate. The specification for the extension is here: Full diff: https://github.com/llvm/llvm-project/pull/133428.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d5359da2a590e..cd5d201c3d5da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
@@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
- SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+ SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
+ SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 98e435c18d3d7..f195adfc0e73d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1361,4 +1361,79 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
// -----
+def SPIRV_GroupNonUniformRotateKHR : SPIRV_Op<"GroupNonUniformRotateKHR", []> {
+ let summary = [{
+ Rotate values across invocations within a subgroup.
+ }];
+
+ let description = [{
+ Return the Value of the invocation whose id within the group is calculated
+ as follows:
+
+ LocalId = SubgroupLocalInvocationId if Execution is Subgroup or
+ LocalInvocationId if Execution is Workgroup
+ RotationGroupSize = ClusterSize when ClusterSize is present, otherwise
+ RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared
+ and SubgroupSize if not.
+ Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) +
+ (LocalId & ~(RotationGroupSize - 1))
+
+ Result Type must be a scalar or vector of floating-point type, integer
+ type, or Boolean type.
+
+ Execution is a Scope. It must be either Workgroup or Subgroup.
+
+ The type of Value must be the same as Result Type.
+
+ Delta must be a scalar of integer type, whose Signedness operand is 0.
+ Delta must be dynamically uniform within Execution.
+
+ Delta is treated as unsigned and the resulting value is undefined if the
+ selected lane is inactive.
+
+ ClusterSize is the size of cluster to use. ClusterSize must be a scalar of
+ integer type, whose Signedness operand is 0. ClusterSize must come from a
+ constant instruction. Behavior is undefined unless ClusterSize is at least
+ 1 and a power of 2. If ClusterSize is greater than the declared
+ SubGroupSize, executing this instruction results in undefined behavior.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %four = spirv.Constant 4 : i32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
+ %1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
+ clustersize(%four) : f32, i32, i32 -> f32
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_3>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_GroupNonUniformRotateKHR]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScopeAttr:$execution_scope,
+ SPIRV_Type:$value,
+ SPIRV_Integer:$delta,
+ Optional<SPIRV_Integer>:$cluster_size
+ );
+
+ let results = (outs
+ SPIRV_Type:$result
+ );
+
+ let hasVerifier = 0;
+
+ let assemblyFormat = [{
+ $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
+ }];
+}
+
+// -----
+
#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 60ae1584d29fb..60b99d51363e9 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -604,3 +604,26 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformRotateKHR
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+ // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
+ %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
+ return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+ // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
+ %four = spirv.Constant 4 : i32
+ %0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+ return %0: f32
+}
|
@Hsiangkai can you merge this on your own or do you want me to do it? |
It would be nice to have some verification or at least constrain operands and result types. For example:
- SPIRV_Type:$result
+ AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result (n.b. It would be nice to have |
Thanks for your review. |
if (getDelta().getType().isSignedInteger()) | ||
return emitOpError("delta must be a singless/unsigned integer"); | ||
|
||
auto clusterSizeVal = getClusterSize(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the full type here since the type is not obvious based on the RHS.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
return emitOpError("delta must be a singless/unsigned integer"); | ||
|
||
auto clusterSizeVal = getClusterSize(); | ||
if (clusterSizeVal) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it's only used inside this if, you can do:
if (clusterSizeVal) { | |
if (TheType clusterSizeVal = getClusterSize()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added some minor (optional) suggestion. Thank you for adding the verification!
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | ||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | ||
|
||
if (getDelta().getType().isSignedInteger()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this check be avoided if delta
is set to be SPIRV_SignlessOrUnsignedInt
? NB This is a good place to look at what types are already defined:
// SPIR-V type definitions |
But I'm not going to lie, I sometimes personally get quite confused looking for the correct type :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
auto clusterSizeVal = getClusterSize(); | ||
if (clusterSizeVal) { | ||
if (clusterSizeVal.getType().isSignedInteger()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
let arguments = (ins | ||
SPIRV_ScopeAttr:$execution_scope, | ||
SPIRV_Type:$value, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we spell the type of the $value
here to be AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>
, which is the same type as a result?
I know the spec says nothing about the type directly and only defines it in terms of the result:
The type of Value must be the same as Result Type.
But I had this discussion with @kuhar in the past: #124571 (comment) and it seems that we agreed that being more specific is better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (with a small nitpick)
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | ||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | ||
|
||
if (TypedValue<Type> clusterSizeVal = getClusterSize()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was suggested to me before to use Value
instead of TypedValue<Type>
in the situation like this: #126555 (comment) For what's worth I cannot see TypedValue<Type>
being used anywhere in the dialect code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your review. I have updated it. I will merge this patch after all tests pass.
Add an instruction under the extension SPV_KHR_subgroup_rotate.
The specification for the extension is here:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html