Skip to content

[DXIL][SPIRV] Lower WaveActiveCountBits intrinsic #113382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19092,6 +19092,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
ArrayRef{OpExpr});
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
Expand Down
18 changes: 18 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s -DTARGET=spv

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_bool
int test_bool(bool expr) {
// CHECK: call {{.*}} @llvm.[[TARGET]].wave.active.countbits
return WaveActiveCountBits(expr);
}

// CHECK: declare i32 @llvm.[[TARGET]].wave.active.countbits(i1) #[[#attr:]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
18 changes: 18 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

int test_too_few_arg() {
return __builtin_hlsl_wave_active_count_bits();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

int test_too_many_arg(bool x) {
return __builtin_hlsl_wave_active_count_bits(x, x);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

struct S { float f; };

int test_bad_conversion(S x) {
return __builtin_hlsl_wave_active_count_bits(x);
// expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ let TargetPrefix = "spv" in {
[IntrNoMem, Commutative] >;
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def WaveAllBitCount : DXILOp<135, waveAllOp> {
let Doc = "returns the count of bits set to 1 across the wave";
let LLVMIntrinsic = int_dx_wave_active_countbits;
let arguments = [Int1Ty];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
36 changes: 36 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Expand Down Expand Up @@ -1917,6 +1920,37 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}

bool SPIRVInstructionSelector::selectWaveActiveCountBits(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();

SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType));

bool Result =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
.addDef(BallotReg)
.addUse(GR.getSPIRVTypeID(BallotType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);

Result &=
BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(BallotReg)
.constrainAllUses(TII, TRI, RBI);

return Result;
}

bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2745,6 +2779,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::u_clamp, GL::UClamp);
case Intrinsic::spv_sclamp:
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_is_first_lane: {
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define void @main(i1 %expr) {
entry:
; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
%0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
ret void
}

declare i32 @llvm.dx.wave.active.countbits(i1)
19 changes: 19 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3

; CHECK-LABEL: Begin function test_fun
; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
define i32 @test_fun(i1 %expr) {
entry:
; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
%0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
ret i32 %0
}

declare i32 @llvm.dx.wave.active.countbits(i1)
Loading