Skip to content

Commit f0eb908

Browse files
authored
[SPIR-V] Add WaveGetLaneIndex() intrinsic support (#85979)
Add support to generate valid SPIR-V for the WaveGetLaneIndex() HLSL builtin. To implement this, I had to fix a few small issues in the backend, like the i8* pointer type being emitted, even if we have the type information elsewhere. Signed-off-by: Nathan Gauër <[email protected]>
1 parent 37785fe commit f0eb908

File tree

10 files changed

+131
-34
lines changed

10 files changed

+131
-34
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,10 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
370370

371371
/// Helper function for building a load instruction for loading a builtin global
372372
/// variable of \p BuiltinValue value.
373-
static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
374-
SPIRVType *VariableType,
375-
SPIRVGlobalRegistry *GR,
376-
SPIRV::BuiltIn::BuiltIn BuiltinValue,
377-
LLT LLType,
378-
Register Reg = Register(0)) {
373+
static Register buildBuiltinVariableLoad(
374+
MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
375+
SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
376+
Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
379377
Register NewRegister =
380378
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
381379
MIRBuilder.getMRI()->setType(NewRegister,
@@ -387,8 +385,9 @@ static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
387385
// Set up the global OpVariable with the necessary builtin decorations.
388386
Register Variable = GR->buildGlobalVariable(
389387
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
390-
SPIRV::StorageClass::Input, nullptr, true, true,
391-
SPIRV::LinkageType::Import, MIRBuilder, false);
388+
SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
389+
/* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
390+
false);
392391

393392
// Load the value from the global variable.
394393
Register LoadedRegister =
@@ -1341,6 +1340,22 @@ static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
13411340
return true;
13421341
}
13431342

1343+
static bool generateWaveInst(const SPIRV::IncomingCall *Call,
1344+
MachineIRBuilder &MIRBuilder,
1345+
SPIRVGlobalRegistry *GR) {
1346+
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
1347+
SPIRV::BuiltIn::BuiltIn Value =
1348+
SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
1349+
1350+
// For now, we only support a single Wave intrinsic with a single return type.
1351+
assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt);
1352+
LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(Call->ReturnType));
1353+
1354+
return buildBuiltinVariableLoad(
1355+
MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
1356+
/* isConst= */ false, /* hasLinkageTy= */ false);
1357+
}
1358+
13441359
static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
13451360
MachineIRBuilder &MIRBuilder,
13461361
SPIRVGlobalRegistry *GR) {
@@ -2229,6 +2244,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
22292244
return generateBarrierInst(Call.get(), MIRBuilder, GR);
22302245
case SPIRV::Dot:
22312246
return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
2247+
case SPIRV::Wave:
2248+
return generateWaveInst(Call.get(), MIRBuilder, GR);
22322249
case SPIRV::GetQuery:
22332250
return generateGetQueryInst(Call.get(), MIRBuilder, GR);
22342251
case SPIRV::ImageSizeQuery:

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def Variable : BuiltinGroup;
4141
def Atomic : BuiltinGroup;
4242
def Barrier : BuiltinGroup;
4343
def Dot : BuiltinGroup;
44+
def Wave : BuiltinGroup;
4445
def GetQuery : BuiltinGroup;
4546
def ImageSizeQuery : BuiltinGroup;
4647
def ImageMiscQuery : BuiltinGroup;
@@ -1143,6 +1144,7 @@ defm : DemangledGetBuiltin<"get_global_size", OpenCL_std, GetQuery, GlobalSize>;
11431144
defm : DemangledGetBuiltin<"get_group_id", OpenCL_std, GetQuery, WorkgroupId>;
11441145
defm : DemangledGetBuiltin<"get_enqueued_local_size", OpenCL_std, GetQuery, EnqueuedWorkgroupSize>;
11451146
defm : DemangledGetBuiltin<"get_num_groups", OpenCL_std, GetQuery, NumWorkgroups>;
1147+
defm : DemangledGetBuiltin<"__hlsl_wave_get_lane_index", GLSL_std_450, Wave, SubgroupLocalInvocationId>;
11461148

11471149
//===----------------------------------------------------------------------===//
11481150
// Class defining an image query builtin record used for lowering the OpenCL

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
493493
Register ResVReg =
494494
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
495495
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
496-
// TODO: check that it's OCL builtin, then apply OpenCL_std.
497-
if (!DemangledName.empty() && CF && CF->isDeclaration() &&
498-
ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
496+
497+
bool isFunctionDecl = CF && CF->isDeclaration();
498+
bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
499+
bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
500+
assert(canUseGLSL != canUseOpenCL &&
501+
"Scenario where both sets are enabled is not supported.");
502+
503+
if (isFunctionDecl && !DemangledName.empty() &&
504+
(canUseGLSL || canUseOpenCL)) {
499505
SmallVector<Register, 8> ArgVRegs;
500506
for (auto Arg : Info.OrigArgs) {
501507
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
@@ -504,12 +510,15 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
504510
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
505511
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
506512
}
507-
if (auto Res = SPIRV::lowerBuiltin(
508-
DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
509-
ResVReg, OrigRetTy, ArgVRegs, GR))
513+
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
514+
: SPIRV::InstructionSet::GLSL_std_450;
515+
if (auto Res =
516+
SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
517+
ResVReg, OrigRetTy, ArgVRegs, GR))
510518
return *Res;
511519
}
512-
if (CF && CF->isDeclaration() && !GR->find(CF, &MF).isValid()) {
520+
521+
if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
513522
// Emit the type info and forward function declaration to the first MBB
514523
// to ensure VReg definition dependencies are valid across all MBBs.
515524
MachineIRBuilder FirstBlockBuilder;

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
723723
AddrSpace = PType->getAddressSpace();
724724
else
725725
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
726-
SPIRVType *SpvElementType;
727-
// At the moment, all opaque pointers correspond to i8 element type.
728-
// TODO: change the implementation once opaque pointers are supported
729-
// in the SPIR-V specification.
730-
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
726+
727+
SPIRVType *SpvElementType = nullptr;
728+
if (auto PType = dyn_cast<TypedPointerType>(Ty))
729+
SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
730+
AccQual, EmitIR);
731+
else
732+
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
733+
731734
// Get access to information about available extensions
732735
const SPIRVSubtarget *ST =
733736
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ void RequirementHandler::initAvailableCapabilitiesForVulkan(
658658

659659
// Provided by all supported Vulkan versions.
660660
addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
661-
Capability::Float64});
661+
Capability::Float64, Capability::GroupNonUniform});
662662
}
663663

664664
} // namespace SPIRV

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
186186
}
187187
case TargetOpcode::G_GLOBAL_VALUE: {
188188
MIB.setInsertPt(*MI->getParent(), MI);
189-
Type *Ty = MI->getOperand(1).getGlobal()->getType();
189+
const auto *Global = MI->getOperand(1).getGlobal();
190+
auto *Ty = TypedPointerType::get(Global->getValueType(),
191+
Global->getType()->getAddressSpace());
190192
SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
191193
break;
192194
}

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,20 +306,19 @@ static bool isNonMangledOCLBuiltin(StringRef Name) {
306306
std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) {
307307
bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
308308
bool IsNonMangledSPIRV = Name.starts_with("__spirv_");
309+
bool IsNonMangledHLSL = Name.starts_with("__hlsl_");
309310
bool IsMangled = Name.starts_with("_Z");
310311

311-
if (!IsNonMangledOCL && !IsNonMangledSPIRV && !IsMangled)
312-
return std::string();
312+
// Otherwise use simple demangling to return the function name.
313+
if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled)
314+
return Name.str();
313315

314316
// Try to use the itanium demangler.
315317
if (char *DemangledName = itaniumDemangle(Name.data())) {
316318
std::string Result = DemangledName;
317319
free(DemangledName);
318320
return Result;
319321
}
320-
// Otherwise use simple demangling to return the function name.
321-
if (IsNonMangledOCL || IsNonMangledSPIRV)
322-
return Name.str();
323322

324323
// Autocheck C++, maybe need to do explicit check of the source language.
325324
// OpenCL C++ built-ins are declared in cl namespace.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; This file generated from the following command:
5+
; clang -cc1 -triple spirv-vulkan-compute -x hlsl -emit-llvm -finclude-default-header -o - - <<EOF
6+
; [numthreads(1, 1, 1)]
7+
; void main() {
8+
; int idx = WaveGetLaneIndex();
9+
; }
10+
; EOF
11+
12+
; CHECK-DAG: OpCapability Shader
13+
; CHECK-DAG: OpCapability GroupNonUniform
14+
; CHECK-DAG: OpDecorate %[[#var:]] BuiltIn SubgroupLocalInvocationId
15+
; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
16+
; CHECK-DAG: %[[#ptri:]] = OpTypePointer Input %[[#int]]
17+
; CHECK-DAG: %[[#ptrf:]] = OpTypePointer Function %[[#int]]
18+
; CHECK-DAG: %[[#var]] = OpVariable %[[#ptri]] Input
19+
20+
; CHECK-NOT: OpDecorate %[[#var]] LinkageAttributes
21+
22+
23+
; ModuleID = '-'
24+
source_filename = "-"
25+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
26+
target triple = "spirv-unknown-vulkan-compute"
27+
28+
; Function Attrs: convergent noinline norecurse nounwind optnone
29+
define internal spir_func void @main() #0 {
30+
entry:
31+
%0 = call token @llvm.experimental.convergence.entry()
32+
%idx = alloca i32, align 4
33+
; CHECK: %[[#idx:]] = OpVariable %[[#ptrf]] Function
34+
35+
%1 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %0) ]
36+
; CHECK: %[[#tmp:]] = OpLoad %[[#int]] %[[#var]]
37+
38+
store i32 %1, ptr %idx, align 4
39+
; CHECK: OpStore %[[#idx]] %[[#tmp]]
40+
41+
ret void
42+
}
43+
44+
; Function Attrs: norecurse
45+
define void @main.1() #1 {
46+
entry:
47+
call void @main()
48+
ret void
49+
}
50+
51+
; Function Attrs: convergent
52+
declare i32 @__hlsl_wave_get_lane_index() #2
53+
54+
; Function Attrs: convergent nocallback nofree nosync nounwind willreturn memory(none)
55+
declare token @llvm.experimental.convergence.entry() #3
56+
57+
attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
58+
attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
59+
attributes #2 = { convergent }
60+
attributes #3 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
61+
62+
!llvm.module.flags = !{!0, !1}
63+
!llvm.ident = !{!2}
64+
65+
!0 = !{i32 1, !"wchar_size", i32 4}
66+
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
67+
!2 = !{!"clang version 19.0.0git (/usr/local/google/home/nathangauer/projects/llvm-project/clang bc6fd04b73a195981ee77823cf1382d04ab96c44)"}
68+

llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; RUN: llc -mtriple=spirv-unknown-unknown -O0 %s -o - | FileCheck %s
22

3+
; CHECK-DAG: OpDecorate %[[#SubgroupLocalInvocationId:]] BuiltIn SubgroupLocalInvocationId
34
; CHECK-DAG: %[[#bool:]] = OpTypeBool
45
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
56
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
@@ -37,10 +38,10 @@ l1_continue:
3738
; CHECK-NEXT: OpBranch %[[#l1_header]]
3839

3940
l1_end:
40-
%call = call spir_func i32 @_Z3absi(i32 0) [ "convergencectrl"(token %tl1) ]
41+
%call = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %tl1) ]
4142
br label %end
4243
; CHECK-DAG: %[[#l1_end]] = OpLabel
43-
; CHECK-DAG: %[[#]] = OpFunctionCall
44+
; CHECK-DAG: %[[#]] = OpLoad %[[#]] %[[#SubgroupLocalInvocationId]]
4445
; CHECK-NEXT: OpBranch %[[#end:]]
4546

4647
l2:
@@ -76,6 +77,4 @@ declare token @llvm.experimental.convergence.entry()
7677
declare token @llvm.experimental.convergence.control()
7778
declare token @llvm.experimental.convergence.loop()
7879

79-
; This intrinsic is not convergent. This is only because the backend doesn't
80-
; support convergent operations yet.
81-
declare spir_func i32 @_Z3absi(i32) convergent
80+
declare i32 @__hlsl_wave_get_lane_index() convergent

llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
22
;
3-
; CHECK-SPIRV-DAG: %[[#i8:]] = OpTypeInt 8 0
43
; CHECK-SPIRV-DAG: %[[#i32:]] = OpTypeInt 32 0
54
; CHECK-SPIRV-DAG: %[[#one:]] = OpConstant %[[#i32]] 1
65
; CHECK-SPIRV-DAG: %[[#two:]] = OpConstant %[[#i32]] 2
@@ -13,7 +12,6 @@
1312
; CHECK-SPIRV: %[[#test_arr2:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
1413
; CHECK-SPIRV: %[[#test_arr:]] = OpVariable %[[#const_i32x3_ptr]] UniformConstant %[[#test_arr_init]]
1514

16-
; CHECK-SPIRV-DAG: %[[#const_i8_ptr:]] = OpTypePointer UniformConstant %[[#i8]]
1715
; CHECK-SPIRV-DAG: %[[#i32x3_ptr:]] = OpTypePointer Function %[[#i32x3]]
1816

1917
; CHECK-SPIRV: %[[#arr:]] = OpVariable %[[#i32x3_ptr]] Function

0 commit comments

Comments
 (0)