-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add spirv-to-llvm conversion for group operations #115501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Lukas Sommer <[email protected]>
@llvm/pr-subscribers-mlir-spirv Author: Lukas Sommer (sommerlukas) Changes: Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect. Similar to #111864, lower the operations to builtin functions understood by SPIR-V tools. Patch is 45.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115501.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ba31936af5510d..3fd78d2b08a6bb 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType) {
+ Type resultType,
+ bool convergent = true) {
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(symbolTable, name));
if (func)
@@ -1038,7 +1040,9 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
- func.setConvergent(true);
+ if (convergent) {
+ func.setConvergent(true);
+ }
func.setNoUnwind(true);
func.setWillReturn(true);
return func;
@@ -1046,10 +1050,13 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
LLVM::LLVMFuncOp func,
- ValueRange args) {
+ ValueRange args,
+ bool convergent = true) {
auto call = builder.create<LLVM::CallOp>(loc, func, args);
call.setCConv(func.getCConv());
- call.setConvergentAttr(func.getConvergentAttr());
+ if (convergent) {
+ call.setConvergentAttr(func.getConvergentAttr());
+ }
call.setNoUnwindAttr(func.getNoUnwindAttr());
call.setWillReturnAttr(func.getWillReturnAttr());
return call;
@@ -1089,6 +1096,186 @@ class ControlBarrierPattern
}
};
+namespace {
+
+StringRef getTypeMangling(Type type, bool isSigned) {
+ return llvm::TypeSwitch<Type, StringRef>(type)
+ .Case<Float16Type>([](auto) { return "Dh"; })
+ .template Case<Float32Type>([](auto) { return "f"; })
+ .template Case<Float64Type>([](auto) { return "d"; })
+ .template Case<IntegerType>([isSigned](IntegerType intTy) {
+ switch (intTy.getWidth()) {
+ case 1:
+ return "b";
+ case 8:
+ return (isSigned) ? "a" : "c";
+ case 16:
+ return (isSigned) ? "s" : "t";
+ case 32:
+ return (isSigned) ? "i" : "j";
+ case 64:
+ return (isSigned) ? "l" : "m";
+ default: {
+ assert(false && "Unsupported integer width");
+ return "";
+ }
+ }
+ })
+ .Default([](auto) {
+ assert(false && "No mangling defined");
+ return "";
+ });
+}
+
+template <typename ReduceOp>
+constexpr StringLiteral getGroupFuncName() {
+ assert(false && "No builtin defined");
+ return "";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
+ return "_Z17__spirv_GroupIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
+ return "_Z17__spirv_GroupFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
+ return "_Z17__spirv_GroupSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
+ return "_Z17__spirv_GroupUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
+ return "_Z17__spirv_GroupFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
+ return "_Z17__spirv_GroupSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
+ return "_Z17__spirv_GroupUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
+ return "_Z17__spirv_GroupFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
+ return "_Z27__spirv_GroupNonUniformIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
+ return "_Z27__spirv_GroupNonUniformFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
+ return "_Z27__spirv_GroupNonUniformIMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
+ return "_Z27__spirv_GroupNonUniformFMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
+ return "_Z27__spirv_GroupNonUniformSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
+ return "_Z27__spirv_GroupNonUniformUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
+ return "_Z27__spirv_GroupNonUniformFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
+ return "_Z33__spirv_GroupNonUniformBitwiseAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
+ return "_Z32__spirv_GroupNonUniformBitwiseOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
+ return "_Z33__spirv_GroupNonUniformBitwiseXorii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
+ return "_Z33__spirv_GroupNonUniformLogicalAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
+ return "_Z32__spirv_GroupNonUniformLogicalOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
+ return "_Z33__spirv_GroupNonUniformLogicalXorii";
+}
+} // namespace
+
+template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
+class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
+public:
+ using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type retTy = op.getResult().getType();
+ if (!retTy.isIntOrFloat()) {
+ return failure();
+ }
+ SmallString<20> funcName = getGroupFuncName<ReduceOp>();
+ funcName += getTypeMangling(retTy, false);
+
+ Type i32Ty = rewriter.getI32Type();
+ SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
+ if constexpr (NonUniform) {
+ if (adaptor.getClusterSize()) {
+ funcName += "j";
+ paramTypes.push_back(i32Ty);
+ }
+ }
+
+ Operation *symbolTable =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+
+ LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
+ symbolTable, funcName, paramTypes, retTy, !NonUniform);
+
+ Location loc = op.getLoc();
+ Value scope = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value groupOp = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
+ SmallVector<Value> operands{scope, groupOp};
+ operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
+
+ auto call =
+ createSPIRVBuiltinCall(loc, rewriter, func, operands, !NonUniform);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
@@ -1722,7 +1909,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ReturnPattern, ReturnValuePattern,
// Barrier ops
- ControlBarrierPattern>(patterns.getContext(), typeConverter);
+ ControlBarrierPattern,
+
+ // Group reduction operations
+ GroupReducePattern<spirv::GroupIAddOp>,
+ GroupReducePattern<spirv::GroupFAddOp>,
+ GroupReducePattern<spirv::GroupFMinOp>,
+ GroupReducePattern<spirv::GroupUMinOp>,
+ GroupReducePattern<spirv::GroupSMinOp, /*Signed*/ true>,
+ GroupReducePattern<spirv::GroupFMaxOp>,
+ GroupReducePattern<spirv::GroupUMaxOp>,
+ GroupReducePattern<spirv::GroupSMaxOp, /*Signed*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed*/ true,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed*/ true,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed*/ false,
+ /*NonUniform*/ true>
+ >(patterns.getContext(), typeConverter);
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..8c8fc50349e795
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL: llvm.func spir_funccc @_Z17__spirv_GroupSMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupUMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFMaxiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupSMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupUMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFMiniif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFAddiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupIAddiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: llvm.func @group_reduce_iadd(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_iadd(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupIAdd <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fadd(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fadd(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFAdd <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fmin(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fmin(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMin <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_umin(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_umin(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupUMin <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_smin(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_smin(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupSMin <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fmax(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fmax(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMax <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_umax(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_umax(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupUMax <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_smax(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_smax(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupSMax <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_iadd(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_inclusive_scan_iadd(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupIAdd <Workgroup> <InclusiveScan> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_fadd(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_inclusive_scan_fadd(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFAdd <Workgroup> <InclusiveScan> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_fmin(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_inclusive_scan_fmin(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMin <Workgroup> <InclusiveScan> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_umin(
+//...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Lukas Sommer (sommerlukas) Changes: Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect. Similar to #111864, lower the operations to builtin functions understood by SPIR-V tools. Patch is 45.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115501.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index ba31936af5510d..3fd78d2b08a6bb 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType) {
+ Type resultType,
+ bool convergent = true) {
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(symbolTable, name));
if (func)
@@ -1038,7 +1040,9 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
symbolTable->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes));
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
- func.setConvergent(true);
+ if (convergent) {
+ func.setConvergent(true);
+ }
func.setNoUnwind(true);
func.setWillReturn(true);
return func;
@@ -1046,10 +1050,13 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
LLVM::LLVMFuncOp func,
- ValueRange args) {
+ ValueRange args,
+ bool convergent = true) {
auto call = builder.create<LLVM::CallOp>(loc, func, args);
call.setCConv(func.getCConv());
- call.setConvergentAttr(func.getConvergentAttr());
+ if (convergent) {
+ call.setConvergentAttr(func.getConvergentAttr());
+ }
call.setNoUnwindAttr(func.getNoUnwindAttr());
call.setWillReturnAttr(func.getWillReturnAttr());
return call;
@@ -1089,6 +1096,186 @@ class ControlBarrierPattern
}
};
+namespace {
+
+StringRef getTypeMangling(Type type, bool isSigned) {
+ return llvm::TypeSwitch<Type, StringRef>(type)
+ .Case<Float16Type>([](auto) { return "Dh"; })
+ .template Case<Float32Type>([](auto) { return "f"; })
+ .template Case<Float64Type>([](auto) { return "d"; })
+ .template Case<IntegerType>([isSigned](IntegerType intTy) {
+ switch (intTy.getWidth()) {
+ case 1:
+ return "b";
+ case 8:
+ return (isSigned) ? "a" : "c";
+ case 16:
+ return (isSigned) ? "s" : "t";
+ case 32:
+ return (isSigned) ? "i" : "j";
+ case 64:
+ return (isSigned) ? "l" : "m";
+ default: {
+ assert(false && "Unsupported integer width");
+ return "";
+ }
+ }
+ })
+ .Default([](auto) {
+ assert(false && "No mangling defined");
+ return "";
+ });
+}
+
+template <typename ReduceOp>
+constexpr StringLiteral getGroupFuncName() {
+ assert(false && "No builtin defined");
+ return "";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
+ return "_Z17__spirv_GroupIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
+ return "_Z17__spirv_GroupFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
+ return "_Z17__spirv_GroupSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
+ return "_Z17__spirv_GroupUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
+ return "_Z17__spirv_GroupFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
+ return "_Z17__spirv_GroupSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
+ return "_Z17__spirv_GroupUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
+ return "_Z17__spirv_GroupFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
+ return "_Z27__spirv_GroupNonUniformIAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
+ return "_Z27__spirv_GroupNonUniformFAddii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
+ return "_Z27__spirv_GroupNonUniformIMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
+ return "_Z27__spirv_GroupNonUniformFMulii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
+ return "_Z27__spirv_GroupNonUniformSMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
+ return "_Z27__spirv_GroupNonUniformUMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
+ return "_Z27__spirv_GroupNonUniformFMinii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformSMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformUMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
+ return "_Z27__spirv_GroupNonUniformFMaxii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
+ return "_Z33__spirv_GroupNonUniformBitwiseAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
+ return "_Z32__spirv_GroupNonUniformBitwiseOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
+ return "_Z33__spirv_GroupNonUniformBitwiseXorii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
+ return "_Z33__spirv_GroupNonUniformLogicalAndii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
+ return "_Z32__spirv_GroupNonUniformLogicalOrii";
+}
+template <>
+constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
+ return "_Z33__spirv_GroupNonUniformLogicalXorii";
+}
+} // namespace
+
+template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
+class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
+public:
+ using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type retTy = op.getResult().getType();
+ if (!retTy.isIntOrFloat()) {
+ return failure();
+ }
+ SmallString<20> funcName = getGroupFuncName<ReduceOp>();
+ funcName += getTypeMangling(retTy, false);
+
+ Type i32Ty = rewriter.getI32Type();
+ SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
+ if constexpr (NonUniform) {
+ if (adaptor.getClusterSize()) {
+ funcName += "j";
+ paramTypes.push_back(i32Ty);
+ }
+ }
+
+ Operation *symbolTable =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+
+ LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
+ symbolTable, funcName, paramTypes, retTy, !NonUniform);
+
+ Location loc = op.getLoc();
+ Value scope = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value groupOp = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
+ SmallVector<Value> operands{scope, groupOp};
+ operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
+
+ auto call =
+ createSPIRVBuiltinCall(loc, rewriter, func, operands, !NonUniform);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
@@ -1722,7 +1909,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ReturnPattern, ReturnValuePattern,
// Barrier ops
- ControlBarrierPattern>(patterns.getContext(), typeConverter);
+ ControlBarrierPattern,
+
+ // Group reduction operations
+ GroupReducePattern<spirv::GroupIAddOp>,
+ GroupReducePattern<spirv::GroupFAddOp>,
+ GroupReducePattern<spirv::GroupFMinOp>,
+ GroupReducePattern<spirv::GroupUMinOp>,
+ GroupReducePattern<spirv::GroupSMinOp, /*Signed*/ true>,
+ GroupReducePattern<spirv::GroupFMaxOp>,
+ GroupReducePattern<spirv::GroupUMaxOp>,
+ GroupReducePattern<spirv::GroupSMaxOp, /*Signed*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed*/ true,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed*/ true,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed*/ false,
+ /*NonUniform*/ true>,
+ GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed*/ false,
+ /*NonUniform*/ true>
+ >(patterns.getContext(), typeConverter);
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..8c8fc50349e795
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/group-ops-to-llvm.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL: llvm.func spir_funccc @_Z17__spirv_GroupSMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupUMaxiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFMaxiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupSMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupUMiniij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFMiniif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupFAddiif(i32, i32, f32) -> f32 attributes {convergent, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z17__spirv_GroupIAddiij(i32, i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: llvm.func @group_reduce_iadd(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_iadd(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupIAdd <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fadd(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fadd(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFAdd <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fmin(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fmin(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMin <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_umin(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_umin(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupUMin <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_smin(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMiniij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_smin(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupSMin <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_fmax(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMaxiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_reduce_fmax(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMax <Workgroup> <Reduce> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_umax(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupUMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_umax(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupUMax <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_reduce_smax(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupSMaxiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_reduce_smax(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupSMax <Workgroup> <Reduce> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_iadd(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupIAddiij(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> i32
+// CHECK: llvm.return %[[VAL_3]] : i32
+// CHECK: }
+spirv.func @group_inclusive_scan_iadd(%arg0: i32) -> i32 "None" {
+ %0 = spirv.GroupIAdd <Workgroup> <InclusiveScan> %arg0 : i32
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_fadd(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFAddiif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_inclusive_scan_fadd(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFAdd <Workgroup> <InclusiveScan> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_fmin(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.call spir_funccc @_Z17__spirv_GroupFMiniif(%[[VAL_1]], %[[VAL_2]], %[[VAL_0]]) {convergent, no_unwind, will_return} : (i32, i32, f32) -> f32
+// CHECK: llvm.return %[[VAL_3]] : f32
+// CHECK: }
+spirv.func @group_inclusive_scan_fmin(%arg0: f32) -> f32 "None" {
+ %0 = spirv.GroupFMin <Workgroup> <InclusiveScan> %arg0 : f32
+ spirv.ReturnValue %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @group_inclusive_scan_umin(
+//...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just nits really
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but nits
Signed-off-by: Lukas Sommer <[email protected]>
(no objections from me)
[mlir][spirv] Add spirv-to-llvm conversion for group operations (#115501) Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect. Similar to llvm#111864, lower the operations to builtin functions understood by SPIR-V tools. --------- Signed-off-by: Lukas Sommer <[email protected]>
Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect.
Similar to #111864, lower the operations to builtin functions understood by SPIR-V tools.