Skip to content

Commit 1f15252

Browse files
[SPIR-V] Add support for HLSL SV_GroupIndex (#130670)
This PR lowers the `llvm.spv.flattened.thread.id.in.group` intrinsic to a load of the `LocalInvocationIndex` builtin variable.
1 parent 5497709 commit 1f15252

File tree

4 files changed

+82
-0
lines changed

4 files changed

+82
-0
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@ SPIR-V backend, along with their descriptions and argument details.
405405
- 32-bit Integer
406406
- `[32-bit Integer]`
407407
- Retrieves the thread ID within a workgroup. Essential for identifying execution context in parallel compute operations.
408+
* - `int_spv_flattened_thread_id_in_group`
409+
- 32-bit Integer
410+
- `[]`
411+
- Provides the flattened index of a given thread within its group (HLSL `SV_GroupIndex`).
408412
* - `int_spv_create_handle`
409413
- Pointer
410414
- `[8-bit Integer]`

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ let TargetPrefix = "spv" in {
6161
def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
6262
def int_spv_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
6363
def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
64+
def int_spv_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
6465
def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
6566
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
6667
def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
341341
bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
342342
Register ResVReg, const SPIRVType *ResType,
343343
MachineInstr &I) const;
344+
bool loadBuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
345+
Register ResVReg, const SPIRVType *ResType,
346+
MachineInstr &I) const;
344347
bool loadHandleBeforePosition(Register &HandleReg, const SPIRVType *ResType,
345348
GIntrinsic &HandleDef, MachineInstr &Pos) const;
346349
};
@@ -3065,6 +3068,15 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
30653068
// builtin variable
30663069
return loadVec3BuiltinInputID(SPIRV::BuiltIn::WorkgroupId, ResVReg, ResType,
30673070
I);
3071+
case Intrinsic::spv_flattened_thread_id_in_group:
3072+
// The HLSL SV_GroupIndex semantic is lowered to the
3073+
// llvm.spv.flattened.thread.id.in.group() intrinsic in LLVM IR for the
3074+
// SPIR-V backend.
3075+
//
3076+
// In the SPIR-V backend, llvm.spv.flattened.thread.id.in.group is
3077+
// translated to a `LocalInvocationIndex` builtin variable.
3078+
return loadBuiltinInputID(SPIRV::BuiltIn::LocalInvocationIndex, ResVReg,
3079+
ResType, I);
30683080
case Intrinsic::spv_fdot:
30693081
return selectFloatDot(ResVReg, ResType, I);
30703082
case Intrinsic::spv_udot:
@@ -4011,6 +4023,40 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
40114023
return Result && MIB.constrainAllUses(TII, TRI, RBI);
40124024
}
40134025

4026+
// Generate the instructions to load 32-bit integer builtin input IDs/indices,
4027+
// such as LocalInvocationIndex.
//
// A global variable in the Input storage class is built for BuiltInValue,
// named by getLinkStringForBuiltIn(BuiltInValue), and an OpLoad of that
// variable is emitted in front of I. The OpLoad defines ResVReg and uses the
// SPIR-V type ID of ResType as its result type.
//
// Returns the result of constrainAllUses() on the emitted OpLoad.
4028+
bool SPIRVInstructionSelector::loadBuiltinInputID(
4029+
SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
4030+
const SPIRVType *ResType, MachineInstr &I) const {
4031+
MachineIRBuilder MIRBuilder(I);
4032+
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
4033+
ResType, MIRBuilder, SPIRV::StorageClass::Input);
4034+
4035+
// Create new register for the input ID builtin variable.
// Its low-level type is a pointer in the address space that corresponds to
// the Input storage class, and it is associated with PtrType.
4036+
Register NewRegister =
4037+
MIRBuilder.getMRI()->createVirtualRegister(GR.getRegClass(PtrType));
4038+
MIRBuilder.getMRI()->setType(
4039+
NewRegister,
4040+
LLT::pointer(storageClassToAddressSpace(SPIRV::StorageClass::Input),
4041+
GR.getPointerSize()));
4042+
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
4043+
4044+
// Build global variable with the necessary decorations for the input ID
4045+
// builtin variable.
// NOTE(review): the meaning of the positional nullptr/true/true/false
// arguments depends on buildGlobalVariable's declaration, which is not
// visible here -- confirm against it before changing them.
4046+
Register Variable = GR.buildGlobalVariable(
4047+
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
4048+
SPIRV::StorageClass::Input, nullptr, true, true,
4049+
SPIRV::LinkageType::Import, MIRBuilder, false);
4050+
4051+
// Load uint value from the global variable.
4052+
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
4053+
.addDef(ResVReg)
4054+
.addUse(GR.getSPIRVTypeID(ResType))
4055+
.addUse(Variable);
4056+
4057+
return MIB.constrainAllUses(TII, TRI, RBI);
4058+
}
4059+
40144060
SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
40154061
MachineInstr &I) const {
40164062
MachineIRBuilder MIRBuilder(I);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
3+
4+
; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#ptr_Input_int:]] = OpTypePointer Input %[[#int]]
6+
; CHECK-DAG: %[[#LocalInvocationIndex:]] = OpVariable %[[#ptr_Input_int]] Input
7+
8+
; CHECK-DAG: OpEntryPoint GLCompute {{.*}} %[[#LocalInvocationIndex]]
9+
; CHECK-DAG: OpName %[[#LocalInvocationIndex]] "__spirv_BuiltInLocalInvocationIndex"
10+
; CHECK-DAG: OpDecorate %[[#LocalInvocationIndex]] BuiltIn LocalInvocationIndex
11+
12+
target triple = "spirv-unknown-vulkan-library"
13+
14+
declare void @local_index_user(i32)
15+
16+
; Entry point under test. Attribute group #1 marks it as an HLSL compute
; shader with numthreads 1,1,1. It reads the flattened thread index through
; the intrinsic and forwards it to @local_index_user.
; Function Attrs: convergent noinline norecurse
17+
define void @main() #1 {
18+
entry:
19+
20+
; CHECK: %[[#load:]] = OpLoad %[[#int]] %[[#LocalInvocationIndex]]
21+
%1 = call i32 @llvm.spv.flattened.thread.id.in.group()
22+
23+
; Pass the index to the externally declared @local_index_user so that the
; intrinsic result has a use.
call spir_func void @local_index_user(i32 %1)
24+
ret void
25+
}
26+
27+
; Function Attrs: nounwind willreturn memory(none)
28+
declare i32 @llvm.spv.flattened.thread.id.in.group() #3
29+
30+
attributes #1 = { convergent noinline norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
31+
attributes #3 = { nounwind willreturn memory(none) }

0 commit comments

Comments
 (0)