Skip to content

Commit e520b28

Browse files
authored
[DXIL][SPIRV] Lower WaveActiveCountBits intrinsic (#113382)
```
- Add codegen for the LLVM builtin to the SPIR-V/DirectX intrinsic in CGBuiltin.cpp
- Add lowering of the SPIR-V intrinsic to the SPIR-V backend in SPIRVInstructionSelector.cpp
- Add lowering of the DirectX intrinsic to the DXIL op in DXIL.td
- Add test cases to illustrate the passes
- Add a test case for semantic analysis
```
Resolves #80176
1 parent 4fb43c4 commit e520b28

File tree

10 files changed

+120
-0
lines changed

10 files changed

+120
-0
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19092,6 +19092,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1909219092
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1909319093
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1909419094
}
19095+
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
19096+
Value *OpExpr = EmitScalarExpr(E->getArg(0));
19097+
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
19098+
return EmitRuntimeCall(
19099+
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
19100+
ArrayRef{OpExpr});
19101+
}
1909519102
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1909619103
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
1909719104
// defined in SPIRVBuiltins.td. So instead we manually get the matching name

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class CGHLSLRuntime {
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9292
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
9393
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
94+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
9495
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
9596
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
9697
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s -DTARGET=dx
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s -DTARGET=spv
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_bool
11+
int test_bool(bool expr) {
12+
// CHECK: call {{.*}} @llvm.[[TARGET]].wave.active.countbits
13+
return WaveActiveCountBits(expr);
14+
}
15+
16+
// CHECK: declare i32 @llvm.[[TARGET]].wave.active.countbits(i1) #[[#attr:]]
17+
18+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
int test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_count_bits();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
int test_too_many_arg(bool x) {
9+
return __builtin_hlsl_wave_active_count_bits(x, x);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
struct S { float f; };
14+
15+
int test_bad_conversion(S x) {
16+
return __builtin_hlsl_wave_active_count_bits(x);
17+
// expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
18+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
9090
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
9191
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
9292
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
93+
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9394
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
9495
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9596
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ let TargetPrefix = "spv" in {
8585
[IntrNoMem, Commutative] >;
8686
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8787
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
88+
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8889
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8990
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
9091
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,3 +880,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
880880
let stages = [Stages<DXIL1_0, [all_stages]>];
881881
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
882882
}
883+
884+
def WaveAllBitCount : DXILOp<135, waveAllOp> {
885+
let Doc = "returns the count of bits set to 1 across the wave";
886+
let LLVMIntrinsic = int_dx_wave_active_countbits;
887+
let arguments = [Int1Ty];
888+
let result = Int32Ty;
889+
let stages = [Stages<DXIL1_0, [all_stages]>];
890+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
891+
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
256256
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
257257
MachineInstr &I) const;
258258

259+
bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
260+
MachineInstr &I) const;
261+
259262
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
260263
MachineInstr &I) const;
261264

@@ -1917,6 +1920,37 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
19171920
return Result;
19181921
}
19191922

1923+
// Selects the spv_wave_active_countbits intrinsic as a pair of SPIR-V
// instructions: OpGroupNonUniformBallot on the input operand, followed by
// OpGroupNonUniformBallotBitCount (Reduce) on that ballot, both at Subgroup
// scope. Returns true only if both instructions were constrained successfully.
bool SPIRVInstructionSelector::selectWaveActiveCountBits(
1924+
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
1925+
// The input value is expected in operand 2 and must be a register.
assert(I.getNumOperands() == 3);
1926+
assert(I.getOperand(2).isReg());
1927+
MachineBasicBlock &BB = *I.getParent();
1928+
1929+
// The ballot result type is a 4-element vector of 32-bit integers.
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
1930+
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
1931+
Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType));
1932+
1933+
// Emit the ballot of the input value at Subgroup scope into BallotReg.
bool Result =
1934+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
1935+
.addDef(BallotReg)
1936+
.addUse(GR.getSPIRVTypeID(BallotType))
1937+
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
1938+
.addUse(I.getOperand(2).getReg())
1939+
.constrainAllUses(TII, TRI, RBI);
1940+
1941+
// Emit the Reduce bit count of the ballot into the result register; the
// overall result stays true only if this step also succeeds.
Result &=
1942+
BuildMI(BB, I, I.getDebugLoc(),
1943+
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
1944+
.addDef(ResVReg)
1945+
.addUse(GR.getSPIRVTypeID(ResType))
1946+
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
1947+
.addImm(SPIRV::GroupOperation::Reduce)
1948+
.addUse(BallotReg)
1949+
.constrainAllUses(TII, TRI, RBI);
1950+
1951+
return Result;
1952+
}
1953+
19201954
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
19211955
const SPIRVType *ResType,
19221956
MachineInstr &I) const {
@@ -2745,6 +2779,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
27452779
return selectExtInst(ResVReg, ResType, I, CL::u_clamp, GL::UClamp);
27462780
case Intrinsic::spv_sclamp:
27472781
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
2782+
case Intrinsic::spv_wave_active_countbits:
2783+
return selectWaveActiveCountBits(ResVReg, ResType, I);
27482784
case Intrinsic::spv_wave_is_first_lane: {
27492785
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
27502786
return BuildMI(BB, I, I.getDebugLoc(),
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define void @main(i1 %expr) {
4+
entry:
5+
; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
6+
%0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
7+
ret void
8+
}
9+
10+
declare i32 @llvm.dx.wave.active.countbits(i1)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
6+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
7+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
8+
9+
; CHECK-LABEL: Begin function test_fun
10+
; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
11+
define i32 @test_fun(i1 %expr) {
12+
entry:
13+
; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
14+
; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
15+
%0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
16+
ret i32 %0
17+
}
18+
19+
declare i32 @llvm.spv.wave.active.countbits(i1) ; intrinsic called by test_fun above

0 commit comments

Comments
 (0)