-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[DXIL][SPIRV] Lower WaveActiveCountBits
intrinsic
#113382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang-codegen Author: Finn Plummer (inbelic) Changes
Full diff: https://github.com/llvm/llvm-project/pull/113382.diff 10 Files Affected:
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..f178499156eb1f 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18879,6 +18879,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
+ case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
+ Value *OpExpr = EmitScalarExpr(E->getArg(0));
+ Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
+ return EmitRuntimeCall(
+ Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
+ ArrayRef{OpExpr});
+ }
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..ada57b070783c6 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -89,6 +89,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
new file mode 100755
index 00000000000000..3e1f8fcaace9c2
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_bool
+int test_bool(bool expr) {
+ // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ]
+ // CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}})
+ // CHECK: ret i32 %[[RET]]
+ return WaveActiveCountBits(expr);
+}
+
+// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]]
+// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
new file mode 100644
index 00000000000000..02f45eb30b377a
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+ return __builtin_hlsl_wave_active_count_bits();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+int test_too_many_arg(bool x) {
+ return __builtin_hlsl_wave_active_count_bits(x, x);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct S { float f; };
+
+int test_bad_conversion(S x) {
+ return __builtin_hlsl_wave_active_count_bits(x);
+ // expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..306829f6d1375d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -84,6 +84,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 6df2eb156a0774..28bbd72b122208 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
+ def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..7cb6b0a1af9a33 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -820,3 +820,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+
+def WaveAllBitCount : DXILOp<135, waveAllOp> {
+ let Doc = "returns the count of bits set to 1 across the wave";
+ let LLVMIntrinsic = int_dx_wave_active_countbits;
+ let arguments = [Int1Ty];
+ let result = Int32Ty;
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..9f670f793574ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -230,6 +230,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1762,6 +1765,36 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}
+bool SPIRVInstructionSelector::selectWaveActiveCountBits(
+ Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
+
+ bool Result =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
+ .addDef(BallotReg)
+ .addUse(GR.getSPIRVTypeID(BallotType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+ .addUse(I.getOperand(2).getReg());
+
+ Result |=
+ BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+ .addImm(0)
+ .addUse(BallotReg)
+ .constrainAllUses(TII, TRI, RBI);
+
+ return Result;
+}
+
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2559,6 +2592,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
} break;
case Intrinsic::spv_saturate:
return selectSaturate(ResVReg, ResType, I);
+ case Intrinsic::spv_wave_active_countbits:
+ return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_is_first_lane: {
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
new file mode 100644
index 00000000000000..5d321372433198
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main(i1 %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
+ %0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
+ ret void
+}
+
+declare i32 @llvm.dx.wave.active.countbits(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
new file mode 100644
index 00000000000000..29944054111ac0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
@@ -0,0 +1,19 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_fun
+; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
+define i32 @test_fun(i1 %expr) {
+entry:
+; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
+; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
+ %0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
+ ret i32 %0
+}
+
+declare i32 @llvm.dx.wave.active.countbits(i1)
|
Failing testcase is unrelated. Will need to rebase on review changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! The only change I think is essential is the missing constrainAllUses
6e58d11
to
0162e76
Compare
WaveActiveCountBits
intrinsic
- add codegen for llvm builtin to spirv/directx intrinsic in CGBuiltin.cpp - add lowering of spirv intrinsic to spirv backend in SPIRVInstructionSelector.cpp - add lowering of directx intrinsic to dxil op in DXIL.td - add test cases to illustrate passes - add test case for semantic analysis
- add constrainAllUses to first spirv op - update testcase for ease of reading - use enum instead of int equivalent for documentation
- get the proper register class - use result and instead of or
0162e76
to
cff8387
Compare
Rebased to resolve conflicts. |
``` - add codegen for llvm builtin to spirv/directx intrinsic in CGBuiltin.cpp - add lowering of spirv intrinsic to spirv backend in SPIRVInstructionSelector.cpp - add lowering of directx intrinsic to dxil op in DXIL.td - add test cases to illustrate passes - add test case for semantic analysis ``` Resolves llvm#80176
Resolves #80176