Skip to content

Commit 28f8234

Browse files
committed
[HLSL][SPIR-V] Add SV_GroupThreadID semantic support
The HLSL SV_GroupThreadID semantic attribute is lowered into the @llvm.spv.thread.id.in.group intrinsic in LLVM IR for the SPIR-V target. In the SPIR-V backend, this intrinsic is now correctly translated to a `LocalInvocationId` builtin variable. Fixes #70122
1 parent dc8d779 commit 28f8234

File tree

6 files changed

+144
-29
lines changed

6 files changed

+144
-29
lines changed

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
391391
}
392392
if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
393393
llvm::Function *GroupThreadIDIntrinsic =
394-
CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group);
394+
CGM.getIntrinsic(getGroupThreadIdIntrinsic());
395395
return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
396396
}
397397
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class CGHLSLRuntime {
8686
GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
8787
GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians)
8888
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
89+
GENERATE_HLSL_INTRINSIC_FUNCTION(GroupThreadId, thread_id_in_group)
8990
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
9091
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
9192
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,36 @@
1-
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
2+
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv
23

3-
// Make sure SV_GroupThreadID translated into dx.thread.id.in.group.
4+
// Make sure SV_GroupThreadID is translated into dx.thread.id.in.group for the DirectX target and spv.thread.id.in.group for the SPIR-V target.
45

5-
// CHECK: define void @foo()
6-
// CHECK: %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
7-
// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
6+
// CHECK: define void @foo()
7+
// CHECK: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
8+
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
9+
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(i32 %[[#ID]])
810
[shader("compute")]
911
[numthreads(8,8,1)]
1012
void foo(uint Idx : SV_GroupThreadID) {}
1113

12-
// CHECK: define void @bar()
13-
// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
14-
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
15-
// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
16-
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
17-
// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
14+
// CHECK: define void @bar()
15+
// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
16+
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
17+
// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1)
18+
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
19+
// CHECK-DXIL: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
20+
// CHECK-SPIRV: call spir_func void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
1821
[shader("compute")]
1922
[numthreads(8,8,1)]
2023
void bar(uint2 Idx : SV_GroupThreadID) {}
2124

2225
// CHECK: define void @test()
23-
// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
26+
// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
2427
// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
25-
// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
28+
// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1)
2629
// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
27-
// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2)
30+
// CHECK: %[[#ID_Z:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 2)
2831
// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
29-
// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
32+
// CHECK-DXIL: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
33+
// CHECK-SPIRV: call spir_func void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
3034
[shader("compute")]
3135
[numthreads(8,8,1)]
3236
void test(uint3 Idx : SV_GroupThreadID) {}

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ let TargetPrefix = "spv" in {
5959

6060
// The following intrinsic(s) are mirrored from IntrinsicsDirectX.td for HLSL support.
6161
def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
62+
def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
6263
def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
6364
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
6465
def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
265265
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
266266
MachineInstr &I) const;
267267

268+
bool selectSpvGroupThreadId(Register ResVReg, const SPIRVType *ResType,
269+
MachineInstr &I) const;
270+
268271
bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType,
269272
MachineInstr &I, unsigned Opcode) const;
270273

@@ -309,6 +312,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
309312
SPIRVType *widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const;
310313
void extractSubvector(Register &ResVReg, const SPIRVType *ResType,
311314
Register &ReadReg, MachineInstr &InsertionPoint) const;
315+
bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
316+
Register ResVReg, const SPIRVType *ResType,
317+
MachineInstr &I) const;
312318
};
313319

314320
} // end anonymous namespace
@@ -2852,6 +2858,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
28522858
break;
28532859
case Intrinsic::spv_thread_id:
28542860
return selectSpvThreadId(ResVReg, ResType, I);
2861+
case Intrinsic::spv_thread_id_in_group:
2862+
return selectSpvGroupThreadId(ResVReg, ResType, I);
28552863
case Intrinsic::spv_fdot:
28562864
return selectFloatDot(ResVReg, ResType, I);
28572865
case Intrinsic::spv_udot:
@@ -3551,30 +3559,29 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
35513559
.constrainAllUses(TII, TRI, RBI);
35523560
}
35533561

3554-
bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
3555-
const SPIRVType *ResType,
3556-
MachineInstr &I) const {
3557-
// DX intrinsic: @llvm.dx.thread.id(i32)
3558-
// ID Name Description
3559-
// 93 ThreadId reads the thread ID
3560-
3562+
// Generate the instructions to load 3-element vector builtin input
3563+
// IDs/Indices.
3564+
// For example: SV_DispatchThreadID, SV_GroupThreadID, etc.
3565+
bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
3566+
SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
3567+
const SPIRVType *ResType, MachineInstr &I) const {
35613568
MachineIRBuilder MIRBuilder(I);
35623569
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
35633570
const SPIRVType *Vec3Ty =
35643571
GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
35653572
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
35663573
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
35673574

3568-
// Create new register for GlobalInvocationID builtin variable.
3575+
// Create new register for the input ID builtin variable.
35693576
Register NewRegister =
35703577
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
35713578
MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64));
35723579
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
35733580

3574-
// Build GlobalInvocationID global variable with the necessary decorations.
3581+
// Build global variable with the necessary decorations for the input ID
3582+
// builtin variable.
35753583
Register Variable = GR.buildGlobalVariable(
3576-
NewRegister, PtrType,
3577-
getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
3584+
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
35783585
SPIRV::StorageClass::Input, nullptr, true, true,
35793586
SPIRV::LinkageType::Import, MIRBuilder, false);
35803587

@@ -3591,12 +3598,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
35913598
.addUse(GR.getSPIRVTypeID(Vec3Ty))
35923599
.addUse(Variable);
35933600

3594-
// Get Thread ID index. Expecting operand is a constant immediate value,
3601+
// Get the input ID index. The operand should be a constant immediate value,
35953602
// wrapped in a type assignment.
35963603
assert(I.getOperand(2).isReg());
35973604
const uint32_t ThreadId = foldImm(I.getOperand(2), MRI);
35983605

3599-
// Extract the thread ID from the loaded vector value.
3606+
// Extract the input ID from the loaded vector value.
36003607
MachineBasicBlock &BB = *I.getParent();
36013608
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
36023609
.addDef(ResVReg)
@@ -3606,6 +3613,32 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
36063613
return Result && MIB.constrainAllUses(TII, TRI, RBI);
36073614
}
36083615

3616+
bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
3617+
const SPIRVType *ResType,
3618+
MachineInstr &I) const {
3619+
// DX intrinsic: @llvm.dx.thread.id(i32)
3620+
// ID Name Description
3621+
// 93 ThreadId reads the thread ID
3622+
//
3623+
// In SPIR-V, llvm.dx.thread.id maps to a `GlobalInvocationId` builtin
3624+
// variable
3625+
return loadVec3BuiltinInputID(SPIRV::BuiltIn::GlobalInvocationId, ResVReg,
3626+
ResType, I);
3627+
}
3628+
3629+
bool SPIRVInstructionSelector::selectSpvGroupThreadId(Register ResVReg,
3630+
const SPIRVType *ResType,
3631+
MachineInstr &I) const {
3632+
// DX intrinsic: @llvm.dx.thread.id.in.group(i32)
3633+
// ID Name Description
3634+
// 95 GroupThreadId Reads the thread ID within the group
3635+
//
3636+
// In SPIR-V, llvm.dx.thread.id.in.group maps to a `LocalInvocationId` builtin
3637+
// variable
3638+
return loadVec3BuiltinInputID(SPIRV::BuiltIn::LocalInvocationId, ResVReg,
3639+
ResType, I);
3640+
}
3641+
36093642
SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
36103643
MachineInstr &I) const {
36113644
MachineIRBuilder MIRBuilder(I);
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; This file was generated from the following command:
5+
; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header - -o - <<EOF
6+
; [shader("compute")]
7+
; [numthreads(1,1,1)]
8+
; void main(uint3 ID : SV_GroupThreadID) {}
9+
; EOF
10+
11+
; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
12+
; CHECK-DAG: %[[#v3int:]] = OpTypeVector %[[#int]] 3
13+
; CHECK-DAG: %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]]
14+
; CHECK-DAG: %[[#tempvar:]] = OpUndef %[[#v3int]]
15+
; CHECK-DAG: %[[#LocalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input
16+
17+
; CHECK-DAG: OpEntryPoint GLCompute {{.*}} %[[#LocalInvocationId]]
18+
; CHECK-DAG: OpName %[[#LocalInvocationId]] "__spirv_BuiltInLocalInvocationId"
19+
; CHECK-DAG: OpDecorate %[[#LocalInvocationId]] LinkageAttributes "__spirv_BuiltInLocalInvocationId" Import
20+
; CHECK-DAG: OpDecorate %[[#LocalInvocationId]] BuiltIn LocalInvocationId
21+
22+
; ModuleID = '-'
23+
source_filename = "-"
24+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
25+
target triple = "spirv-unknown-vulkan-library"
26+
27+
; Function Attrs: noinline norecurse nounwind optnone
28+
define internal spir_func void @main(<3 x i32> noundef %ID) #0 {
29+
entry:
30+
%ID.addr = alloca <3 x i32>, align 16
31+
store <3 x i32> %ID, ptr %ID.addr, align 16
32+
ret void
33+
}
34+
35+
; Function Attrs: norecurse
36+
define void @main.1() #1 {
37+
entry:
38+
39+
; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]]
40+
; CHECK: %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0
41+
%0 = call i32 @llvm.spv.thread.id.in.group(i32 0)
42+
43+
; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0
44+
%1 = insertelement <3 x i32> poison, i32 %0, i64 0
45+
46+
; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]]
47+
; CHECK: %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1
48+
%2 = call i32 @llvm.spv.thread.id.in.group(i32 1)
49+
50+
; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1
51+
%3 = insertelement <3 x i32> %1, i32 %2, i64 1
52+
53+
; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]]
54+
; CHECK: %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2
55+
%4 = call i32 @llvm.spv.thread.id.in.group(i32 2)
56+
57+
; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2
58+
%5 = insertelement <3 x i32> %3, i32 %4, i64 2
59+
60+
call void @main(<3 x i32> %5)
61+
ret void
62+
}
63+
64+
; Function Attrs: nounwind willreturn memory(none)
65+
declare i32 @llvm.spv.thread.id.in.group(i32) #2
66+
67+
attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
68+
attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
69+
attributes #2 = { nounwind willreturn memory(none) }
70+
71+
!llvm.module.flags = !{!0, !1}
72+
!llvm.ident = !{!2}
73+
74+
!0 = !{i32 1, !"wchar_size", i32 4}
75+
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
76+
!2 = !{!"clang version 19.0.0git ([email protected]:llvm/llvm-project.git 91600507765679e92434ec7c5edb883bf01f847f)"}

0 commit comments

Comments
 (0)