Skip to content

Commit afb6daf

Browse files
authored
[clang][HLSL] Add WaveIsFirstLane() intrinsic (#103299)
This commits add the WaveIsFirstLane() hlsl intrinsinc. This intrinsic uses the convergence intrinsincs for the SPIR-V backend. On the DXIL side, I'm not sure what the strategy is for convergence, so I implemented that like in DXC: a normal builtin function. Signed-off-by: Nathan Gauër <[email protected]>
1 parent 5914566 commit afb6daf

File tree

12 files changed

+140
-22
lines changed

12 files changed

+140
-22
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4679,6 +4679,12 @@ def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
46794679
let Prototype = "unsigned int()";
46804680
}
46814681

4682+
def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
4683+
let Spellings = ["__builtin_hlsl_wave_is_first_lane"];
4684+
let Attributes = [NoThrow, Const];
4685+
let Prototype = "bool()";
4686+
}
4687+
46824688
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
46834689
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
46844690
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18723,6 +18723,10 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1872318723
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
1872418724
{}, false, true));
1872518725
}
18726+
case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
18727+
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
18728+
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
18729+
}
1872618730
}
1872718731
return nullptr;
1872818732
}

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class CGHLSLRuntime {
8484
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
8585
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
8686
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
87+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
8788

8889
//===----------------------------------------------------------------------===//
8990
// End of reserved area for HLSL intrinsic getters.

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,5 +1796,9 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
17961796
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_index)
17971797
__attribute__((convergent)) uint WaveGetLaneIndex();
17981798

1799+
_HLSL_AVAILABILITY(shadermodel, 6.0)
1800+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
1801+
__attribute__((convergent)) bool WaveIsFirstLane();
1802+
17991803
} // namespace hlsl
18001804
#endif //_HLSL_HLSL_INTRINSICS_H_
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
5+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
7+
8+
[numthreads(1, 1, 1)]
9+
void main() {
10+
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
11+
12+
// CHECK-SPIRV: %[[#loop_tok:]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_tok]]) ]
13+
while (true) {
14+
15+
// CHECK-DXIL: %[[#]] = call i1 @llvm.dx.wave.is.first.lane()
16+
// CHECK-SPIRV: %[[#]] = call i1 @llvm.spv.wave.is.first.lane()
17+
// CHECK-SPIRV-SAME: [ "convergencectrl"(token %[[#loop_tok]]) ]
18+
if (WaveIsFirstLane()) {
19+
break;
20+
}
21+
}
22+
23+
// CHECK-DXIL: %[[#]] = call i1 @llvm.dx.wave.is.first.lane()
24+
// CHECK-SPIRV: %[[#]] = call i1 @llvm.spv.wave.is.first.lane()
25+
// CHECK-SPIRV-SAME: [ "convergencectrl"(token %[[#entry_tok]]) ]
26+
if (WaveIsFirstLane()) {
27+
return;
28+
}
29+
}
30+
31+
// CHECK-DXIL: i1 @llvm.dx.wave.is.first.lane() #[[#attr:]]
32+
// CHECK-SPIRV: i1 @llvm.spv.wave.is.first.lane() #[[#attr:]]
33+
34+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,6 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
7979
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
8080
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
8181
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
82+
83+
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8284
}

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,5 @@ let TargetPrefix = "spv" in {
7979
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
8080
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8181
[IntrNoMem, Commutative] >;
82+
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8283
}

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,3 +746,12 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
746746
let result = HandleTy;
747747
let stages = [Stages<DXIL1_6, [all_stages]>];
748748
}
749+
750+
def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
751+
let Doc = "returns 1 for the first lane in the wave";
752+
let LLVMIntrinsic = int_dx_wave_is_first_lane;
753+
let arguments = [];
754+
let result = Int1Ty;
755+
let stages = [Stages<DXIL1_0, [all_stages]>];
756+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
757+
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2351,6 +2351,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
23512351
} break;
23522352
case Intrinsic::spv_saturate:
23532353
return selectSaturate(ResVReg, ResType, I);
2354+
case Intrinsic::spv_wave_is_first_lane: {
2355+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
2356+
return BuildMI(BB, I, I.getDebugLoc(),
2357+
TII.get(SPIRV::OpGroupNonUniformElect))
2358+
.addDef(ResVReg)
2359+
.addUse(GR.getSPIRVTypeID(ResType))
2360+
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
2361+
}
23542362
default: {
23552363
std::string DiagMsg;
23562364
raw_string_ostream OS(DiagMsg);

llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,40 @@ class SPIRVStripConvergentIntrinsics : public FunctionPass {
4141
virtual bool runOnFunction(Function &F) override {
4242
DenseSet<Instruction *> ToRemove;
4343

44+
// Is the instruction is a convergent intrinsic, add it to kill-list and
45+
// returns true. Returns false otherwise.
46+
auto CleanupIntrinsic = [&](IntrinsicInst *II) {
47+
if (II->getIntrinsicID() != Intrinsic::experimental_convergence_entry &&
48+
II->getIntrinsicID() != Intrinsic::experimental_convergence_loop &&
49+
II->getIntrinsicID() != Intrinsic::experimental_convergence_anchor)
50+
return false;
51+
52+
II->replaceAllUsesWith(UndefValue::get(II->getType()));
53+
ToRemove.insert(II);
54+
return true;
55+
};
56+
57+
// Replace the given CallInst by a similar CallInst with no convergencectrl
58+
// attribute.
59+
auto CleanupCall = [&](CallInst *CI) {
60+
auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
61+
if (!OB.has_value())
62+
return;
63+
64+
auto *NewCall = CallBase::removeOperandBundle(
65+
CI, LLVMContext::OB_convergencectrl, CI);
66+
NewCall->copyMetadata(*CI);
67+
CI->replaceAllUsesWith(NewCall);
68+
ToRemove.insert(CI);
69+
};
70+
4471
for (BasicBlock &BB : F) {
4572
for (Instruction &I : BB) {
46-
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
47-
if (II->getIntrinsicID() !=
48-
Intrinsic::experimental_convergence_entry &&
49-
II->getIntrinsicID() !=
50-
Intrinsic::experimental_convergence_loop &&
51-
II->getIntrinsicID() !=
52-
Intrinsic::experimental_convergence_anchor) {
73+
if (auto *II = dyn_cast<IntrinsicInst>(&I))
74+
if (CleanupIntrinsic(II))
5375
continue;
54-
}
55-
56-
II->replaceAllUsesWith(UndefValue::get(II->getType()));
57-
ToRemove.insert(II);
58-
} else if (auto *CI = dyn_cast<CallInst>(&I)) {
59-
auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
60-
if (!OB.has_value())
61-
continue;
62-
63-
auto *NewCall = CallBase::removeOperandBundle(
64-
CI, LLVMContext::OB_convergencectrl, CI);
65-
NewCall->copyMetadata(*CI);
66-
CI->replaceAllUsesWith(NewCall);
67-
ToRemove.insert(CI);
68-
}
76+
if (auto *CI = dyn_cast<CallInst>(&I))
77+
CleanupCall(CI);
6978
}
7079
}
7180

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define void @main() #0 {
4+
entry:
5+
; CHECK: call i1 @dx.op.waveIsFirstLane(i32 110)
6+
%0 = call i1 @llvm.dx.wave.is.first.lane()
7+
ret void
8+
}
9+
10+
declare i1 @llvm.dx.wave.is.first.lane() #1
11+
12+
attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
13+
attributes #1 = { convergent nocallback nofree nosync nounwind willreturn }
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
5+
target triple = "spirv-unknown-vulkan-compute"
6+
7+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
8+
; CHECK-DAG: %[[#uint_3:]] = OpConstant %[[#uint]] 3
9+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
10+
11+
define spir_func void @main() #0 {
12+
entry:
13+
%0 = call token @llvm.experimental.convergence.entry()
14+
; CHECK: %[[#]] = OpGroupNonUniformElect %[[#bool]] %[[#uint_3]]
15+
%1 = call i1 @llvm.spv.wave.is.first.lane() [ "convergencectrl"(token %0) ]
16+
ret void
17+
}
18+
19+
declare i32 @__hlsl_wave_get_lane_index() #1
20+
21+
attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
22+
attributes #1 = { convergent }
23+
24+
!llvm.module.flags = !{!0, !1}
25+
26+
!0 = !{i32 1, !"wchar_size", i32 4}
27+
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}

0 commit comments

Comments
 (0)