Skip to content

Commit 76245d4

Browse files
committed
[SPIR-V] Add support for HLSL SV_GroupIndex
This PR lowers the llvm.spv.flattened.thread.id.in.group intrinsic to a `LocalInvocationIndex` builtin variable.
1 parent fbf0276 commit 76245d4

File tree

4 files changed

+81
-0
lines changed

4 files changed

+81
-0
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@ SPIR-V backend, along with their descriptions and argument details.
405405
- 32-bit Integer
406406
- `[32-bit Integer]`
407407
- Retrieves the thread ID within a workgroup. Essential for identifying execution context in parallel compute operations.
408+
* - `int_spv_flattened_thread_id_in_group`
409+
- 32-bit Integer
410+
- `[]`
411+
- Provides a flattened index for a given thread within a given group (SV_GroupIndex).
408412
* - `int_spv_create_handle`
409413
- Pointer
410414
- `[8-bit Integer]`

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ let TargetPrefix = "spv" in {
6161
def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
6262
def int_spv_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
6363
def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
64+
def int_spv_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
6465
def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
6566
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
6667
def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
341341
bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
342342
Register ResVReg, const SPIRVType *ResType,
343343
MachineInstr &I) const;
344+
bool loadBuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
345+
Register ResVReg, const SPIRVType *ResType,
346+
MachineInstr &I) const;
344347
bool loadHandleBeforePosition(Register &HandleReg, const SPIRVType *ResType,
345348
GIntrinsic &HandleDef, MachineInstr &Pos) const;
346349
};
@@ -3059,6 +3062,15 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
30593062
// builtin variable
30603063
return loadVec3BuiltinInputID(SPIRV::BuiltIn::WorkgroupId, ResVReg, ResType,
30613064
I);
3065+
case Intrinsic::spv_flattened_thread_id_in_group:
3066+
// The HLSL SV_GroupIndex semantic is lowered to
3067+
// llvm.spv.flattened.thread.id.in.group() intrinsic in LLVM IR for SPIR-V
3068+
// backend.
3069+
//
3070+
// In SPIR-V backend, llvm.spv.flattened.thread.id.in.group is translated to
3071+
// a `LocalInvocationIndex` builtin variable
3072+
return loadBuiltinInputID(SPIRV::BuiltIn::LocalInvocationIndex, ResVReg,
3073+
ResType, I);
30623074
case Intrinsic::spv_fdot:
30633075
return selectFloatDot(ResVReg, ResType, I);
30643076
case Intrinsic::spv_udot:
@@ -4005,6 +4017,38 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
40054017
return Result && MIB.constrainAllUses(TII, TRI, RBI);
40064018
}
40074019

4020+
// Generate the instructions to load a scalar 32-bit integer builtin input
4021+
// ID/index (e.g. LocalInvocationIndex) into ResVReg. ResType is not used here: the pointee and load type are always a 32-bit integer type created below.
4022+
bool SPIRVInstructionSelector::loadBuiltinInputID(
4023+
SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
4024+
const SPIRVType *ResType, MachineInstr &I) const {
4025+
MachineIRBuilder MIRBuilder(I);
4026+
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
4027+
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
4028+
U32Type, MIRBuilder, SPIRV::StorageClass::Input);
4029+
4030+
// Create a new pointer-typed virtual register for the input ID builtin variable and give it the Input-storage-class pointer SPIR-V type.
4031+
Register NewRegister =
4032+
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
4033+
MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64));
4034+
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
4035+
4036+
// Build the global variable (named via getLinkStringForBuiltIn, Input storage
4037+
// class, Import linkage) that represents the input ID builtin variable.
4038+
Register Variable = GR.buildGlobalVariable(
4039+
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
4040+
SPIRV::StorageClass::Input, nullptr, true, true,
4041+
SPIRV::LinkageType::Import, MIRBuilder, false);
4042+
4043+
// Emit an OpLoad, placed before I, that reads the 32-bit integer value from the global variable into ResVReg.
4044+
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
4045+
.addDef(ResVReg)
4046+
.addUse(GR.getSPIRVTypeID(U32Type))
4047+
.addUse(Variable);
4048+
4049+
return MIB.constrainAllUses(TII, TRI, RBI);
4050+
}
4051+
40084052
SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
40094053
MachineInstr &I) const {
40104054
MachineIRBuilder MIRBuilder(I);
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#ptr_Input_int:]] = OpTypePointer Input %[[#int]]
6+
; CHECK-DAG: %[[#LocalInvocationIndex:]] = OpVariable %[[#ptr_Input_int]] Input
7+
8+
; CHECK-DAG: OpEntryPoint GLCompute {{.*}} %[[#LocalInvocationIndex]]
9+
; CHECK-DAG: OpName %[[#LocalInvocationIndex]] "__spirv_BuiltInLocalInvocationIndex"
10+
; CHECK-DAG: OpDecorate %[[#LocalInvocationIndex]] LinkageAttributes "__spirv_BuiltInLocalInvocationIndex" Import
11+
; CHECK-DAG: OpDecorate %[[#LocalInvocationIndex]] BuiltIn LocalInvocationIndex
12+
13+
target triple = "spirv-unknown-vulkan-library"
14+
15+
declare void @local_index_user(i32)
16+
17+
; Function Attrs: convergent noinline norecurse
18+
define void @main() #1 {
19+
entry:
20+
21+
; CHECK: %[[#load:]] = OpLoad %[[#int]] %[[#LocalInvocationIndex]]
22+
%1 = call i32 @llvm.spv.flattened.thread.id.in.group()
23+
24+
call spir_func void @local_index_user(i32 %1)
25+
ret void
26+
}
27+
28+
; Function Attrs: nounwind willreturn memory(none)
29+
declare i32 @llvm.spv.flattened.thread.id.in.group() #3
30+
31+
attributes #1 = { convergent noinline norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
32+
attributes #3 = { nounwind willreturn memory(none) }

0 commit comments

Comments
 (0)