[mlir][spirv] Add convergent attribute to builtin #122131

Merged

sommerlukas
Contributor

Add the convergent attribute to builtin functions and builtin function calls when lowering SPIR-V non-uniform group functions to LLVM dialect.

Signed-off-by: Lukas Sommer <[email protected]>
@llvmbot
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Lukas Sommer (sommerlukas)

Changes

Add the convergent attribute to builtin functions and builtin function calls when lowering SPIR-V non-uniform group functions to LLVM dialect.


Patch is 21.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122131.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+4-5)
  • (modified) mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir (+36-36)
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index b11511f21d03d4..e79005f208c3d9 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1028,8 +1028,7 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType,
-                                              bool convergent = true) {
+                                              Type resultType) {
   auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(symbolTable, name));
   if (func)
@@ -1040,7 +1039,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
       symbolTable->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes));
   func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
-  func.setConvergent(convergent);
+  func.setConvergent(true);
   func.setNoUnwind(true);
   func.setWillReturn(true);
   return func;
@@ -1253,8 +1252,8 @@ class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
     Operation *symbolTable =
         op->template getParentWithTrait<OpTrait::SymbolTable>();
 
-    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
-        symbolTable, funcName, paramTypes, retTy, !NonUniform);
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
 
     Location loc = op.getLoc();
     Value scope = rewriter.create<LLVM::ConstantOp>(
diff --git a/mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir
index e81048792c45de..ab174ba2b41e48 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir
@@ -2,30 +2,30 @@
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 
-// CHECK-LABEL:   llvm.func spir_funccc @_Z33__spirv_GroupNonUniformLogicalXoriib(i32, i32, i1) -> i1 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z32__spirv_GroupNonUniformLogicalOriib(i32, i32, i1) -> i1 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z33__spirv_GroupNonUniformLogicalAndiib(i32, i32, i1) -> i1 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z33__spirv_GroupNonUniformBitwiseXoriij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z32__spirv_GroupNonUniformBitwiseOriij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z33__spirv_GroupNonUniformBitwiseAndiij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformSMaxiijj(i32, i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, f32) -> f32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformUMaxiij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformSMaxiij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMiniifj(i32, i32, f32, i32) -> f32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMiniif(i32, i32, f32) -> f32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformUMiniij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformSMiniij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMuliif(i32, i32, f32) -> f32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformIMuliijj(i32, i32, i32, i32) -> i32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, f32, i32) -> f32 attributes {no_unwind, will_return}
-// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(i32, i32, i32) -> i32 attributes {no_unwind, will_return}
+// CHECK-LABEL:   llvm.func spir_funccc @_Z33__spirv_GroupNonUniformLogicalXoriib(i32, i32, i1) -> i1 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z32__spirv_GroupNonUniformLogicalOriib(i32, i32, i1) -> i1 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z33__spirv_GroupNonUniformLogicalAndiib(i32, i32, i1) -> i1 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z33__spirv_GroupNonUniformBitwiseXoriij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z32__spirv_GroupNonUniformBitwiseOriij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z33__spirv_GroupNonUniformBitwiseAndiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformSMaxiijj(i32, i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformUMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformSMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMiniifj(i32, i32, f32, i32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMiniif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformUMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformSMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFMuliif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformIMuliijj(i32, i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, f32, i32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
 
 // CHECK-LABEL:   llvm.func @non_uniform_iadd(
 // CHECK-SAME:                                %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_iadd(%arg0: i32) -> i32 "None" {
@@ -38,7 +38,7 @@ spirv.func @non_uniform_iadd(%arg0: i32) -> i32 "None" {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(16 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFAddiifj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {no_unwind, will_return} : (i32, i32, f32, i32) -> f32
+// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFAddiifj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {convergent, no_unwind, will_return} : (i32, i32, f32, i32) -> f32
 // CHECK:           llvm.return %[[VAL_4]] : f32
 // CHECK:         }
 spirv.func @non_uniform_fadd(%arg0: f32) -> f32 "None" {
@@ -52,7 +52,7 @@ spirv.func @non_uniform_fadd(%arg0: f32) -> f32 "None" {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(16 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIMuliijj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {no_unwind, will_return} : (i32, i32, i32, i32) -> i32
+// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIMuliijj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {convergent, no_unwind, will_return} : (i32, i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_4]] : i32
 // CHECK:         }
 spirv.func @non_uniform_imul(%arg0: i32) -> i32 "None" {
@@ -65,7 +65,7 @@ spirv.func @non_uniform_imul(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMuliif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMuliif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
 // CHECK:           llvm.return %[[VAL_3]] : f32
 // CHECK:         }
 spirv.func @non_uniform_fmul(%arg0: f32) -> f32 "None" {
@@ -77,7 +77,7 @@ spirv.func @non_uniform_fmul(%arg0: f32) -> f32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_smin(%arg0: i32) -> i32 "None" {
@@ -89,7 +89,7 @@ spirv.func @non_uniform_smin(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_umin(%arg0: i32) -> i32 "None" {
@@ -101,7 +101,7 @@ spirv.func @non_uniform_umin(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
 // CHECK:           llvm.return %[[VAL_3]] : f32
 // CHECK:         }
 spirv.func @non_uniform_fmin(%arg0: f32) -> f32 "None" {
@@ -114,7 +114,7 @@ spirv.func @non_uniform_fmin(%arg0: f32) -> f32 "None" {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(16 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMiniifj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {no_unwind, will_return} : (i32, i32, f32, i32) -> f32
+// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMiniifj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {convergent, no_unwind, will_return} : (i32, i32, f32, i32) -> f32
 // CHECK:           llvm.return %[[VAL_4]] : f32
 // CHECK:         }
 spirv.func @non_uniform_fmin_cluster(%arg0: f32) -> f32 "None" {
@@ -127,7 +127,7 @@ spirv.func @non_uniform_fmin_cluster(%arg0: f32) -> f32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_smax(%arg0: i32) -> i32 "None" {
@@ -139,7 +139,7 @@ spirv.func @non_uniform_smax(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_umax(%arg0: i32) -> i32 "None" {
@@ -151,7 +151,7 @@ spirv.func @non_uniform_umax(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
 // CHECK:           llvm.return %[[VAL_3]] : f32
 // CHECK:         }
 spirv.func @non_uniform_fmax(%arg0: f32) -> f32 "None" {
@@ -164,7 +164,7 @@ spirv.func @non_uniform_fmax(%arg0: f32) -> f32 "None" {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(16 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformSMaxiijj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {no_unwind, will_return} : (i32, i32, i32, i32) -> i32
+// CHECK:           %[[VAL_4:.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformSMaxiijj(%[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]) {convergent, no_unwind, will_return} : (i32, i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_4]] : i32
 // CHECK:         }
 spirv.func @non_uniform_smax_cluster(%arg0: i32) -> i32 "None" {
@@ -177,7 +177,7 @@ spirv.func @non_uniform_smax_cluster(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                       %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z33__spirv_GroupNonUniformBitwiseAndiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z33__spirv_GroupNonUniformBitwiseAndiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_bitwise_and(%arg0: i32) -> i32 "None" {
@@ -189,7 +189,7 @@ spirv.func @non_uniform_bitwise_and(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                      %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z32__spirv_GroupNonUniformBitwiseOriij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z32__spirv_GroupNonUniformBitwiseOriij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_bitwise_or(%arg0: i32) -> i32 "None" {
@@ -201,7 +201,7 @@ spirv.func @non_uniform_bitwise_or(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                       %[[VAL_0:.*]]: i32) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z33__spirv_GroupNonUniformBitwiseXoriij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z33__spirv_GroupNonUniformBitwiseXoriij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
 // CHECK:         }
 spirv.func @non_uniform_bitwise_xor(%arg0: i32) -> i32 "None" {
@@ -213,7 +213,7 @@ spirv.func @non_uniform_bitwise_xor(%arg0: i32) -> i32 "None" {
 // CHECK-SAME:                                       %[[VAL_0:.*]]: i1) -> i1 {
 // CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z33__spirv_GroupNonUniformLogicalAndiib(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {no_unwind, will_return} : (i32, i32, i1) -> i1
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z33__spirv_GroupNonUniformLogicalAndiib(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i1) -> i1
 // CHECK:           llvm.return %[[VAL_3]] : i1
 // CHECK:         }
 spirv.func @non_uniform_logical_and(%arg0: i1) -> i1 "None" {
@@ -225,7 +225,7 @@ spirv.func @non_uniform_logical_and(%arg0: i1) -> i1 "None" {
 // CHECK-SAME:               ...
[truncated]

Member

@kuhar kuhar left a comment

Can you point to some documentation that justifies setting this unconditionally?

@sommerlukas
Contributor Author

Can you point to some documentation that justifies setting this unconditionally?

The overview defines a convergent operation as:

A convergent operation involves inter-thread communication or synchronization that occurs outside of the memory model, where the set of threads which participate in communication is implicitly affected by control flow.

Based on my understanding, the non-uniform SPIR-V group operations fulfill this definition because (a) they involve inter-thread communication (a reduction) and (b) the set of participating threads is affected by control flow.
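
As a rough illustration of point (b), here is a toy C++ model. It is not MLIR or SPIR-V code: `Lane`, `groupNonUniformIAdd`, and the two `reduce*` helpers are hypothetical names made up for this sketch. It assumes a subgroup in which only the lanes that reach the call site take part in the reduction, and shows that sinking the call into a branch changes the result:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical model: one entry per lane (thread) in a subgroup.
struct Lane {
  int32_t value;
  bool active; // true if control flow brought this lane to the call site
};

// Toy stand-in for a non-uniform group add: only active lanes participate.
int32_t groupNonUniformIAdd(const std::vector<Lane> &lanes) {
  int32_t sum = 0;
  for (const Lane &lane : lanes)
    if (lane.active)
      sum += lane.value;
  return sum;
}

// The call sits before the branch, so every lane reaches it.
int32_t reduceBeforeBranch(const std::vector<int32_t> &values) {
  std::vector<Lane> lanes;
  for (int32_t v : values)
    lanes.push_back({v, true});
  return groupNonUniformIAdd(lanes);
}

// The same call sunk into `if (value % 2 == 0)`: only even lanes reach it.
int32_t reduceSunkIntoBranch(const std::vector<int32_t> &values) {
  std::vector<Lane> lanes;
  for (int32_t v : values)
    lanes.push_back({v, v % 2 == 0});
  return groupNonUniformIAdd(lanes);
}
```

With the values {1, 2, 3, 4}, the call before the branch sums all four lanes (10), while the sunk call sums only the even ones (6). A transformation that makes the call control-dependent on a new condition is therefore not semantics-preserving, and as far as I understand that is the kind of transformation the `convergent` attribute is meant to rule out.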

@kuhar
Member

kuhar commented Jan 8, 2025

I thought that this is what the old code used to do, no? Now lookupOrCreateSPIRVFn marks any function as convergent -- I'd expect some code comments that say why this is.

@sommerlukas
Contributor Author

I thought that this is what the old code used to do, no? Now lookupOrCreateSPIRVFn marks any function as convergent -- I'd expect some code comments that say why this is.

I'm not sure I understand what you mean here. My understanding is that the non-uniform SPIR-V functions fulfill the definition of a convergent operation and should therefore be marked as such, which the new version of lookupOrCreateSPIRVFn now does.

Is that what you mean?

@kuhar
Member

kuhar commented Jan 8, 2025

My understanding is that the non-uniform SPIR-V functions fulfill the definition of a convergent operation and should therefore be marked as such

This I agree with.

the new version of lookupOrCreateSPIRVFn now does.

This I don't follow. Looking at the previous implementation of this code, I'd think that this was already being handled by

    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        symbolTable, funcName, paramTypes, retTy, !NonUniform);

I'm not very familiar with this code, and it seems odd to me that a very generic-sounding function like lookupOrCreateSPIRVFn would set this attribute unconditionally for any builtin (not just non-uniform).

@victor-eds
Contributor

I see what @kuhar means. Maybe simply doing:

    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        symbolTable, funcName, paramTypes, retTy, /*isConvergent=*/true);

and keeping lookupOrCreateSPIRVFn's definition as is is better.
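
To make the two options concrete, here is a small self-contained sketch of the lookup-or-create pattern with a defaulted flag. It does not use the MLIR API: `FuncDecl`, `SymbolTable`, and `lookupOrCreateFn` are hypothetical stand-ins, and the real helper takes more parameters (parameter types, result type).

```cpp
#include <map>
#include <string>

// Hypothetical stand-in for LLVM::LLVMFuncOp; not the MLIR class.
struct FuncDecl {
  std::string name;
  bool convergent = false;
  bool noUnwind = false;
  bool willReturn = false;
};

// Hypothetical stand-in for a symbol table keyed by function name.
using SymbolTable = std::map<std::string, FuncDecl>;

// Return the existing declaration if there is one, otherwise create it.
// `convergent` defaults to true, so callers that omit it get the attribute.
FuncDecl &lookupOrCreateFn(SymbolTable &table, const std::string &name,
                           bool convergent = true) {
  auto it = table.find(name);
  if (it != table.end())
    return it->second; // existing declaration is returned unchanged

  FuncDecl decl;
  decl.name = name;
  decl.convergent = convergent;
  decl.noUnwind = true;
  decl.willReturn = true;
  return table.emplace(name, decl).first->second;
}
```

A call site can then spell out the choice, e.g. `lookupOrCreateFn(table, name, /*convergent=*/true)`, while other callers keep relying on the default. Note that in this sketch the flag only matters when the declaration is first created; an existing entry is returned as-is.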

@sommerlukas
Contributor Author

Yeah, sorry for misunderstanding your point yesterday @kuhar.

I had only introduced the convergent parameter to the lookupOrCreateSPIRVFn function in the previous PR that added lowering for group operations: #115501

The use of this function that I changed in this PR was the only use of that function that specifies a value for convergent.

The only other use of lookupOrCreateSPIRVFn already unconditionally sets true through the default value: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp#L1080

That's why I figured I could remove the convergent parameter again.

However, if you and @victor-eds prefer to keep the parameter, I can bring it back, no problem.

WDYT?

@victor-eds
Contributor

I'd rather keep the option as we may have to bring it back again in the future.

Signed-off-by: Lukas Sommer <[email protected]>
@sommerlukas
Contributor Author

I'd rather keep the option as we may have to bring it back again in the future.

Ok, done in the latest commit.

@sommerlukas sommerlukas merged commit 4adeb6c into llvm:main Jan 10, 2025
8 checks passed
@sommerlukas sommerlukas deleted the add-convergent-attribute-non-uniform branch January 10, 2025 08:15
BaiXilin pushed a commit to BaiXilin/llvm-fix-vnni-instr-types that referenced this pull request Jan 12, 2025
Add the `convergent` attribute to builtin functions and builtin function
calls when lowering SPIR-V non-uniform group functions to LLVM dialect.

---------

Signed-off-by: Lukas Sommer <[email protected]>