
[mlir][spirv] Add spirv-to-llvm conversion for group operations #115501


Merged
merged 3 commits into from
Nov 12, 2024

Conversation

sommerlukas
Contributor

This PR adds lowering, from the SPIR-V dialect to the LLVM dialect, for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification.

Similar to #111864, the operations are lowered to builtin functions understood by SPIR-V tools.

@llvmbot
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Lukas Sommer (sommerlukas)

Changes

This PR adds lowering, from the SPIR-V dialect to the LLVM dialect, for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification.

Similar to #111864, the operations are lowered to builtin functions understood by SPIR-V tools.


Patch is 45.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115501.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+235-5)
  • (added) mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir (+312)
  • (added) mlir/test/Conversion/SPIRVToLLVM/non-uniform-ops-to-llvm.mlir (+247)
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ba31936af5510d..3fd78d2b08a6bb 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType) {
+                                              Type resultType,
+                                              bool convergent = true) {
   auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(symbolTable, name));
   if (func)
@@ -1038,7 +1040,9 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
       symbolTable->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes));
   func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
-  func.setConvergent(true);
+  if (convergent) {
+    func.setConvergent(true);
+  }
   func.setNoUnwind(true);
   func.setWillReturn(true);
   return func;
@@ -1046,10 +1050,13 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
 
 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
                                            LLVM::LLVMFuncOp func,
-                                           ValueRange args) {
+                                           ValueRange args,
+                                           bool convergent = true) {
   auto call = builder.create<LLVM::CallOp>(loc, func, args);
   call.setCConv(func.getCConv());
-  call.setConvergentAttr(func.getConvergentAttr());
+  if (convergent) {
+    call.setConvergentAttr(func.getConvergentAttr());
+  }
   call.setNoUnwindAttr(func.getNoUnwindAttr());
   call.setWillReturnAttr(func.getWillReturnAttr());
   return call;
@@ -1089,6 +1096,186 @@ class ControlBarrierPattern
   }
 };
 
+namespace {
+
+StringRef getTypeMangling(Type type, bool isSigned) {
+  return llvm::TypeSwitch<Type, StringRef>(type)
+      .Case<Float16Type>([](auto) { return "Dh"; })
+      .template Case<Float32Type>([](auto) { return "f"; })
+      .template Case<Float64Type>([](auto) { return "d"; })
+      .template Case<IntegerType>([isSigned](IntegerType intTy) {
+        switch (intTy.getWidth()) {
+        case 1:
+          return "b";
+        case 8:
+          return (isSigned) ? "a" : "c";
+        case 16:
+          return (isSigned) ? "s" : "t";
+        case 32:
+          return (isSigned) ? "i" : "j";
+        case 64:
+          return (isSigned) ? "l" : "m";
+        default: {
+          assert(false && "Unsupported integer width");
+          return "";
+        }
+        }
+      })
+      .Default([](auto) {
+        assert(false && "No mangling defined");
+        return "";
+      });
+}
+
+template <typename ReduceOp>
+constexpr StringLiteral getGroupFuncName() {
+  assert(false && "No builtin defined");
+  return "";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
+  return "_Z17__spirv_GroupIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
+  return "_Z17__spirv_GroupFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
+  return "_Z17__spirv_GroupSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
+  return "_Z17__spirv_GroupUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
+  return "_Z17__spirv_GroupFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
+  return "_Z17__spirv_GroupSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
+  return "_Z17__spirv_GroupUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
+  return "_Z17__spirv_GroupFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
+  return "_Z27__spirv_GroupNonUniformIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
+  return "_Z27__spirv_GroupNonUniformFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
+  return "_Z27__spirv_GroupNonUniformIMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
+  return "_Z27__spirv_GroupNonUniformFMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
+  return "_Z27__spirv_GroupNonUniformSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
+  return "_Z27__spirv_GroupNonUniformUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
+  return "_Z27__spirv_GroupNonUniformFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
+  return "_Z27__spirv_GroupNonUniformFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
+  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
+  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
+  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
+  return "_Z33__spirv_GroupNonUniformLogicalAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
+  return "_Z32__spirv_GroupNonUniformLogicalOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
+  return "_Z33__spirv_GroupNonUniformLogicalXorii";
+}
+} // namespace
+
+template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
+class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
+public:
+  using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type retTy = op.getResult().getType();
+    if (!retTy.isIntOrFloat()) {
+      return failure();
+    }
+    SmallString<20> funcName = getGroupFuncName<ReduceOp>();
+    funcName += getTypeMangling(retTy, false);
+
+    Type i32Ty = rewriter.getI32Type();
+    SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
+    if constexpr (NonUniform) {
+      if (adaptor.getClusterSize()) {
+        funcName += "j";
+        paramTypes.push_back(i32Ty);
+      }
+    }
+
+    Operation *symbolTable =
+        op->template getParentWithTrait<OpTrait::SymbolTable>();
+
+    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
+        symbolTable, funcName, paramTypes, retTy, !NonUniform);
+
+    Location loc = op.getLoc();
+    Value scope = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
+    Value groupOp = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
+    SmallVector<Value> operands{scope, groupOp};
+    operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
+
+    auto call =
+        createSPIRVBuiltinCall(loc, rewriter, func, operands, !NonUniform);
+    rewriter.replaceOp(op, call);
+    return success();
+  }
+};
+
 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
 /// should be reachable for conversion to succeed. The structure of the loop in
 /// LLVM dialect will be the following:
@@ -1722,7 +1909,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       ReturnPattern, ReturnValuePattern,
 
       // Barrier ops
-      ControlBarrierPattern>(patterns.getContext(), typeConverter);
+      ControlBarrierPattern,
+
+      // Group reduction operations
+      GroupReducePattern<spirv::GroupIAddOp>,
+      GroupReducePattern<spirv::GroupFAddOp>,
+      GroupReducePattern<spirv::GroupFMinOp>,
+      GroupReducePattern<spirv::GroupUMinOp>,
+      GroupReducePattern<spirv::GroupSMinOp, /*Signed*/ true>,
+      GroupReducePattern<spirv::GroupFMaxOp>,
+      GroupReducePattern<spirv::GroupUMaxOp>,
+      GroupReducePattern<spirv::GroupSMaxOp, /*Signed*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed*/ true,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed*/ true,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed*/ false,
+                         /*NonUniform*/ true>,
+      GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed*/ false,
+                         /*NonUniform*/ true>
+      >(patterns.getContext(), typeConverter);
 
   patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
                                       typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..8c8fc50349e795
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL:   llvm.func spir_funccc @_Z17__spirv_GroupSMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupUMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupFMaxiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupSMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupUMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupFMiniif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupFAddiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupIAddiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL:   llvm.func @group_reduce_iadd(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_iadd(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupIAdd <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_fadd(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_reduce_fadd(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFAdd <Workgroup> <Reduce> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_fmin(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_reduce_fmin(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFMin <Workgroup> <Reduce> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_umin(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_umin(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupUMin <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_smin(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_smin(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupSMin <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_fmax(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_reduce_fmax(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFMax <Workgroup> <Reduce> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_umax(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_umax(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupUMax <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_smax(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_smax(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupSMax <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_iadd(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_inclusive_scan_iadd(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupIAdd <Workgroup> <InclusiveScan> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_fadd(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_inclusive_scan_fadd(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFAdd <Workgroup> <InclusiveScan> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_fmin(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_inclusive_scan_fmin(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFMin <Workgroup> <InclusiveScan> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_umin(
+//...
[truncated]

@llvmbot
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir

 
   patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
                                       typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..8c8fc50349e795
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL:   llvm.func spir_funccc @_Z17__spirv_GroupSMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupUMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupFMaxiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupSMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupUMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupFMiniif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupFAddiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK:         llvm.func spir_funccc @_Z17__spirv_GroupIAddiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL:   llvm.func @group_reduce_iadd(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_iadd(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupIAdd <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_fadd(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_reduce_fadd(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFAdd <Workgroup> <Reduce> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_fmin(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_reduce_fmin(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFMin <Workgroup> <Reduce> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_umin(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_umin(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupUMin <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_smin(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_smin(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupSMin <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_fmax(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_reduce_fmax(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFMax <Workgroup> <Reduce> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_umax(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_umax(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupUMax <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_reduce_smax(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_reduce_smax(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupSMax <Workgroup> <Reduce> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_iadd(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK:           llvm.return %[[VAL_3]] : i32
+// CHECK:         }
+spirv.func @group_inclusive_scan_iadd(%arg0: i32) -> i32 "None" {
+  %0 = spirv.GroupIAdd <Workgroup> <InclusiveScan> %arg0 : i32
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_fadd(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_inclusive_scan_fadd(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFAdd <Workgroup> <InclusiveScan> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_fmin(
+// CHECK-SAME:                                         %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK:           llvm.return %[[VAL_3]] : f32
+// CHECK:         }
+spirv.func @group_inclusive_scan_fmin(%arg0: f32) -> f32 "None" {
+  %0 = spirv.GroupFMin <Workgroup> <InclusiveScan> %arg0 : f32
+  spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL:   llvm.func @group_inclusive_scan_umin(
+//...
[truncated]

github-actions bot commented Nov 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@FMarno FMarno (Contributor) left a comment

just nits really

@victor-eds victor-eds (Contributor) left a comment

LGTM but nits

Signed-off-by: Lukas Sommer <[email protected]>
@kuhar (Member) commented Nov 11, 2024

(no objections from me)

@victor-eds victor-eds merged commit 6ade03d into llvm:main Nov 12, 2024
8 checks passed
@sommerlukas sommerlukas deleted the lower-group-ops-to-llvm branch November 12, 2024 09:09
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…#115501)

Lowering for some of the uniform and non-uniform group operations
defined in section 3.52.21 of the SPIR-V specification from SPIR-V
dialect to LLVM dialect.

Similar to llvm#111864, lower the operations to builtin functions understood
by SPIR-V tools.

---------

Signed-off-by: Lukas Sommer <[email protected]>
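The builtin function names checked in the tests (e.g. `_Z17__spirv_GroupIAddiij`) follow a simplified Itanium-style mangling: `_Z`, the character count of the builtin name, the name itself, then one type code per parameter. The helper below is a hypothetical sketch for illustration only — it is not the pass's actual `getTypeMangling` — assuming scope and group operation mangle as signed i32 (`i`), the value operand as `j` (unsigned i32) or `f` (f32), and the optional cluster size of clustered non-uniform ops as a trailing `j`, matching the FileCheck expectations above:

```python
def mangle_group_builtin(op_name: str, value_code: str,
                         clustered: bool = False) -> str:
    """Sketch of the Itanium-style mangling seen in the test checks.

    op_name: SPIR-V op name without the "__spirv_" prefix, e.g. "GroupIAdd".
    value_code: mangling code of the value operand ("j" for i32, "f" for f32).
    clustered: whether an extra i32 cluster-size operand is appended.
    """
    base = "__spirv_" + op_name
    # Two leading "i" codes for the Scope and GroupOperation constants,
    # then the value operand, then the optional cluster size.
    suffix = "ii" + value_code + ("j" if clustered else "")
    return f"_Z{len(base)}{base}{suffix}"

# Reproduces names from the FileCheck lines above:
print(mangle_group_builtin("GroupIAdd", "j"))  # _Z17__spirv_GroupIAddiij
print(mangle_group_builtin("GroupFMin", "f"))  # _Z17__spirv_GroupFMiniif
```

This is only meant to make the test expectations easier to read; the real lowering builds the name in C++ via `getGroupFuncName` and `getTypeMangling` as shown in the diff.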