Skip to content

Commit 68c16dc

Browse files
inbelicFinn Plummer
authored andcommitted
[DXIL][SPIRV] Lower WaveActiveCountBits intrinsic
- add codegen for llvm builtin to spirv/directx intrinsic in CGBuiltin.cpp - add lowering of spirv intrinsic to spirv backend in SPIRVInstructionSelector.cpp - add lowering of directx intrinsic to dxil op in DXIL.td - add test cases to illustrate passes - add test case for semantic analysis
1 parent 40ea92c commit 68c16dc

File tree

10 files changed

+125
-0
lines changed

10 files changed

+125
-0
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18879,6 +18879,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1887918879
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1888018880
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1888118881
}
18882+
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
18883+
Value *OpExpr = EmitScalarExpr(E->getArg(0));
18884+
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
18885+
return EmitRuntimeCall(
18886+
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
18887+
ArrayRef{OpExpr});
18888+
}
1888218889
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1888318890
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
1888418891
// defined in SPIRVBuiltins.td. So instead we manually get the matching name

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class CGHLSLRuntime {
8989
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
9090
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
92+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
9293
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
9394
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
9495

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_bool
11+
int test_bool(bool expr) {
12+
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}})
15+
// CHECK: ret i32 %[[RET]]
16+
return WaveActiveCountBits(expr);
17+
}
18+
19+
// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]]
20+
// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]]
21+
22+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
int test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_count_bits();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
int test_too_many_arg(bool x) {
9+
return __builtin_hlsl_wave_active_count_bits(x, x);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
struct S { float f; };
14+
15+
int test_bad_conversion(S x) {
16+
return __builtin_hlsl_wave_active_count_bits(x);
17+
// expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
18+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
8484
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
8585
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8686
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
87+
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8788
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
8889
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8990
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
8383
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
8484
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8585
[IntrNoMem, Commutative] >;
86+
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8687
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8788
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
8889
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,3 +820,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
820820
let stages = [Stages<DXIL1_0, [all_stages]>];
821821
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
822822
}
823+
824+
def WaveAllBitCount : DXILOp<135, waveAllOp> {
825+
let Doc = "returns the count of bits set to 1 across the wave";
826+
let LLVMIntrinsic = int_dx_wave_active_countbits;
827+
let arguments = [Int1Ty];
828+
let result = Int32Ty;
829+
let stages = [Stages<DXIL1_0, [all_stages]>];
830+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
831+
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
230230
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
231231
MachineInstr &I) const;
232232

233+
bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
234+
MachineInstr &I) const;
235+
233236
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
234237
MachineInstr &I) const;
235238

@@ -1762,6 +1765,38 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
17621765
return Result;
17631766
}
17641767

1768+
bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
1769+
const SPIRVType *ResType,
1770+
MachineInstr &I) const {
1771+
assert(I.getNumOperands() == 3);
1772+
assert(I.getOperand(2).isReg());
1773+
MachineBasicBlock &BB = *I.getParent();
1774+
1775+
Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1776+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
1777+
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
1778+
1779+
bool Result =
1780+
BuildMI(BB, I, I.getDebugLoc(),
1781+
TII.get(SPIRV::OpGroupNonUniformBallot))
1782+
.addDef(BallotReg)
1783+
.addUse(GR.getSPIRVTypeID(BallotType))
1784+
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
1785+
.addUse(I.getOperand(2).getReg());
1786+
1787+
Result |=
1788+
BuildMI(BB, I, I.getDebugLoc(),
1789+
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
1790+
.addDef(ResVReg)
1791+
.addUse(GR.getSPIRVTypeID(ResType))
1792+
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
1793+
.addImm(0)
1794+
.addUse(BallotReg)
1795+
.constrainAllUses(TII, TRI, RBI);
1796+
1797+
return Result;
1798+
}
1799+
17651800
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
17661801
const SPIRVType *ResType,
17671802
MachineInstr &I) const {
@@ -2559,6 +2594,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
25592594
} break;
25602595
case Intrinsic::spv_saturate:
25612596
return selectSaturate(ResVReg, ResType, I);
2597+
case Intrinsic::spv_wave_active_countbits:
2598+
return selectWaveActiveCountBits(ResVReg, ResType, I);
25622599
case Intrinsic::spv_wave_is_first_lane: {
25632600
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
25642601
return BuildMI(BB, I, I.getDebugLoc(),
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define void @main(i1 %expr) {
4+
entry:
5+
; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
6+
%0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
7+
ret void
8+
}
9+
10+
declare i32 @llvm.dx.wave.active.countbits(i1)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
6+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
7+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
8+
9+
; CHECK-LABEL: Begin function test_fun
10+
; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
11+
define i32 @test_fun(i1 %expr) {
12+
entry:
13+
; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
14+
; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
15+
%0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
16+
ret i32 %0
17+
}
18+
19+
declare i32 @llvm.dx.wave.active.countbits(i1)

0 commit comments

Comments
 (0)