[mlir][spirv] Add instruction OpGroupNonUniformRotateKHR #133428

Merged
merged 5 commits into from
Apr 3, 2025

Conversation

Hsiangkai
Contributor

Add an instruction under the extension SPV_KHR_subgroup_rotate.

The specification for the extension is here:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html

@llvmbot
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

Add an instruction under the extension SPV_KHR_subgroup_rotate.

The specification for the extension is here:
https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html


Full diff: https://github.com/llvm/llvm-project/pull/133428.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td (+75)
  • (modified) mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir (+23)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d5359da2a590e..cd5d201c3d5da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
 def SPIRV_OC_OpUDot                           : I32EnumAttrCase<"OpUDot", 4451>;
@@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
+      SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
       SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 98e435c18d3d7..f195adfc0e73d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -1361,4 +1361,79 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
 
 // -----
 
+def SPIRV_GroupNonUniformRotateKHR : SPIRV_Op<"GroupNonUniformRotateKHR", []> {
+  let summary = [{
+    Rotate values across invocations within a subgroup.
+  }];
+
+  let description = [{
+    Return the Value of the invocation whose id within the group is calculated
+    as follows:
+
+    LocalId = SubgroupLocalInvocationId if Execution is Subgroup or
+              LocalInvocationId if Execution is Workgroup
+    RotationGroupSize = ClusterSize when ClusterSize is present, otherwise
+    RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared
+                        and SubgroupSize if not.
+    Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) +
+                    (LocalId & ~(RotationGroupSize - 1))
+
+    Result Type must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution is a Scope. It must be either Workgroup or Subgroup.
+
+    The type of Value must be the same as Result Type.
+
+    Delta must be a scalar of integer type, whose Signedness operand is 0.
+    Delta must be dynamically uniform within Execution.
+
+    Delta is treated as unsigned and the resulting value is undefined if the
+    selected lane is inactive.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar of
+    integer type, whose Signedness operand is 0. ClusterSize must come from a
+    constant instruction. Behavior is undefined unless ClusterSize is at least
+    1 and a power of 2. If ClusterSize is greater than the declared
+    SubGroupSize, executing this instruction results in undefined behavior.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %four = spirv.Constant 4 : i32
+    %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
+    %1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
+         cluster_size(%four) : f32, i32, i32 -> f32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_3>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_GroupNonUniformRotateKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScopeAttr:$execution_scope,
+    SPIRV_Type:$value,
+    SPIRV_Integer:$delta,
+    Optional<SPIRV_Integer>:$cluster_size
+  );
+
+  let results = (outs
+    SPIRV_Type:$result
+  );
+
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    $execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 60ae1584d29fb..60b99d51363e9 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -604,3 +604,26 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
   %0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
   return %0: i32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformRotateKHR
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
+  %0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @group_non_uniform_rotate_khr
+func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
+  // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
+  %four = spirv.Constant 4 : i32
+  %0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
+  return %0: f32
+}

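The invocation-ID formula quoted in the op description can be checked with a small standalone sketch. This is only an illustration of the arithmetic, not code from the patch: `rotatedInvocationId` is a hypothetical helper name, and the sketch assumes `rotationGroupSize` is already known to be a power of 2 and at least 1, as the description requires.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of MLIR or the SPIR-V headers).
// Computes the invocation whose Value is returned, following the formula in
// the op description:
//   ((LocalId + Delta) & (RotationGroupSize - 1)) +
//   (LocalId & ~(RotationGroupSize - 1))
// Assumption: rotationGroupSize is a power of 2 and at least 1.
constexpr uint32_t rotatedInvocationId(uint32_t localId, uint32_t delta,
                                       uint32_t rotationGroupSize) {
  const uint32_t mask = rotationGroupSize - 1;
  // The first term wraps the rotated position inside the cluster; the second
  // term restores the base offset of the cluster that localId belongs to.
  return ((localId + delta) & mask) + (localId & ~mask);
}

// Lane 3, delta 2, cluster of 4: wraps around inside lanes 0..3.
static_assert(rotatedInvocationId(3, 2, 4) == 1, "wraps within first cluster");
// Lane 6, delta 3, cluster of 4: stays inside lanes 4..7.
static_assert(rotatedInvocationId(6, 3, 4) == 5, "wraps within second cluster");
```

The two `static_assert` lines show that a rotation never leaves the cluster the invocation started in.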
@kuhar
Member

kuhar commented Mar 28, 2025

@Hsiangkai can you merge this on your own or do you want me to do it?

@IgWod-IMG
Contributor

It would be nice to have some verification or at least constrain operands and result types. For example:

Result Type must be a scalar or vector of floating-point type, integer
type, or Boolean type

- SPIRV_Type:$result
+ AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result

(n.b. It would be nice to have SPIRV_ScalarOrVectorOf that takes a list of types but that's outside this PR)

@Hsiangkai
Contributor Author

Thanks for your review.
I have added operand and result constraints and an op verifier to check the validity of the values. I also added several test cases to ensure the checks work.
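For readers without an MLIR build at hand, the kind of checks discussed in this thread can be sketched in plain C++. This is a rough illustration under assumptions, not the verifier that was merged: `Scope`, `verifyRotate`, and the cluster-size message are hypothetical stand-ins, and only the scope message is taken from the snippets quoted in the review. The power-of-2 rule comes from the op description; the sketch assumes the cluster size has already been resolved to a constant.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for the dialect's scope enum; only the two scopes
// the op accepts matter for this sketch.
enum class Scope { Device, Workgroup, Subgroup, Invocation };

// Returns an error message, or std::nullopt when the operands look valid.
// Assumption: clusterSize, when present, is the value of a constant.
std::optional<std::string> verifyRotate(Scope scope,
                                        std::optional<uint32_t> clusterSize) {
  if (scope != Scope::Workgroup && scope != Scope::Subgroup)
    return "execution scope must be 'Workgroup' or 'Subgroup'";

  if (clusterSize) {
    const uint32_t size = *clusterSize;
    // A power of 2 has exactly one bit set; zero is rejected explicitly.
    const bool isPowerOfTwo = size != 0 && (size & (size - 1)) == 0;
    if (!isPowerOfTwo)
      return "cluster size must be at least 1 and a power of 2"; // wording is illustrative
  }
  return std::nullopt;
}
```

With this sketch, `verifyRotate(Scope::Workgroup, 4u)` returns no error, while a cluster size of 3 or a `Device` scope yields a message.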

if (getDelta().getType().isSignedInteger())
  return emitOpError("delta must be a signless/unsigned integer");

auto clusterSizeVal = getClusterSize();
Member

Use the full type here since the type is not obvious based on the RHS.

Contributor Author

Done.

return emitOpError("delta must be a signless/unsigned integer");

auto clusterSizeVal = getClusterSize();
if (clusterSizeVal) {
Member

Since it's only used inside this if, you can do:

Suggested change
if (clusterSizeVal) {
if (TheType clusterSizeVal = getClusterSize()) {

Contributor Author

Done.
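The suggestion above relies on C++ allowing a declaration inside the `if` condition, which keeps the variable out of the enclosing scope. A minimal standalone sketch of the pattern follows, using `std::optional` in place of an MLIR value; `RotateOpSketch` and `clusterSizeIsAcceptable` are hypothetical names, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the op: the cluster size may be absent.
struct RotateOpSketch {
  std::optional<uint32_t> clusterSize;
  std::optional<uint32_t> getClusterSize() const { return clusterSize; }
};

bool clusterSizeIsAcceptable(const RotateOpSketch &op) {
  // clusterSizeVal is declared in the condition, so it is visible only
  // inside this if statement and the body runs only when a value is present.
  if (std::optional<uint32_t> clusterSizeVal = op.getClusterSize()) {
    const uint32_t size = *clusterSizeVal;
    return size != 0 && (size & (size - 1)) == 0;
  }
  // No cluster size given: nothing to check.
  return true;
}
```

The same shape works with any type that converts to `bool` in a condition, which is why spelling out the type in the declaration costs nothing here.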

Contributor

@IgWod-IMG IgWod-IMG left a comment

I added some minor (optional) suggestions. Thank you for adding the verification!

if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");

if (getDelta().getType().isSignedInteger())
Contributor

Could this check be avoided if delta is set to be SPIRV_SignlessOrUnsignedInt? NB This is a good place to look at what types are already defined:

But I'm not going to lie, I sometimes personally get quite confused looking for the correct type :)

Contributor Author

Done.


auto clusterSizeVal = getClusterSize();
if (clusterSizeVal) {
if (clusterSizeVal.getType().isSignedInteger())
Contributor

Same as above.

Contributor Author

Done.


let arguments = (ins
SPIRV_ScopeAttr:$execution_scope,
SPIRV_Type:$value,
Contributor

Should we spell out the type of $value here as AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>, which is the same type as the result?

I know the spec says nothing about the type directly and only defines it in terms of the result:

The type of Value must be the same as Result Type.

But I had this discussion with @kuhar in the past: #124571 (comment) and it seems that we agreed that being more specific is better.

Contributor Author

Done.

Contributor

@IgWod-IMG IgWod-IMG left a comment

LGTM (with a small nitpick)

if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");

if (TypedValue<Type> clusterSizeVal = getClusterSize()) {
Contributor

It was suggested to me before to use Value instead of TypedValue<Type> in a situation like this: #126555 (comment). For what it's worth, I cannot see TypedValue<Type> being used anywhere in the dialect code.

Contributor Author

Thanks for your review. I have updated it. I will merge this patch after all tests pass.

@Hsiangkai Hsiangkai merged commit 2e7ed78 into llvm:main Apr 3, 2025
11 checks passed
4 participants