Skip to content

[HLSL] Implement SV_GroupThreadId semantic #117781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4651,6 +4651,13 @@ def HLSLNumThreads: InheritableAttr {
let Documentation = [NumThreadsDocs];
}

// Attribute record for the HLSL `SV_GroupThreadID` semantic annotation.
// It is accepted on parameters and fields (ParmVar, Field), is gated on the
// HLSL language option, and is documented by HLSLSV_GroupThreadIDDocs.
def HLSLSV_GroupThreadID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"SV_GroupThreadID">];
let Subjects = SubjectList<[ParmVar, Field]>;
let LangOpts = [HLSL];
let Documentation = [HLSLSV_GroupThreadIDDocs];
}

def HLSLSV_GroupID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"SV_GroupID">];
let Subjects = SubjectList<[ParmVar, Field]>;
Expand Down
11 changes: 11 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -7941,6 +7941,17 @@ randomized.
}];
}

def HLSLSV_GroupThreadIDDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
The ``SV_GroupThreadID`` semantic, when applied to an input parameter, identifies
the individual thread within a thread group that the shader is executing in.
This attribute is only supported in compute shaders.

The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid
}];
}

def HLSLSV_GroupIDDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
CGM.getIntrinsic(getThreadIdIntrinsic());
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
}
if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
llvm::Function *GroupThreadIDIntrinsic =
CGM.getIntrinsic(getGroupThreadIdIntrinsic());
return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
}
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
return buildVectorInput(B, GroupIDIntrinsic, Ty);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
GENERATE_HLSL_INTRINSIC_FUNCTION(GroupThreadId, thread_id_in_group)
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Parse/ParseHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
case ParsedAttr::UnknownAttribute:
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
return;
case ParsedAttr::AT_HLSLSV_GroupThreadID:
case ParsedAttr::AT_HLSLSV_GroupID:
case ParsedAttr::AT_HLSLSV_GroupIndex:
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7114,6 +7114,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLWaveSize:
S.HLSL().handleWaveSizeAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupThreadID:
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupID:
S.HLSL().handleSV_GroupIDAttr(D, AL);
break;
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
case attr::HLSLSV_GroupThreadID:
case attr::HLSLSV_GroupID:
if (ST == llvm::Triple::Compute)
return;
Expand Down Expand Up @@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
}

// Handles the SV_GroupThreadID semantic on declaration D.
//
// The declaration's type is first validated by diagnoseInputIDType; when
// that check fails nothing is attached to D. Otherwise a new
// HLSLSV_GroupThreadIDAttr built from AL is added to D.
void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
  const auto DeclTy = cast<ValueDecl>(D)->getType();
  if (diagnoseInputIDType(DeclTy, AL)) {
    auto &Ctx = getASTContext();
    D->addAttr(::new (Ctx) HLSLSV_GroupThreadIDAttr(Ctx, AL));
  }
}

void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!diagnoseInputIDType(VD->getType(), AL))
Expand Down
36 changes: 36 additions & 0 deletions clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv

// Make sure SV_GroupThreadID is translated into dx.thread.id.in.group for the directx target and spv.thread.id.in.group for the spirv target.

// CHECK: define void @foo()
// CHECK: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
// CHECK-DXIL: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(i32 %[[#ID]])
[shader("compute")]
[numthreads(8,8,1)]
void foo(uint Idx : SV_GroupThreadID) {}

// CHECK: define void @bar()
// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1)
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
// CHECK-DXIL: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
// CHECK-SPIRV: call spir_func void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
[shader("compute")]
[numthreads(8,8,1)]
void bar(uint2 Idx : SV_GroupThreadID) {}

// CHECK: define void @test()
// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0)
// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1)
// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
// CHECK: %[[#ID_Z:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 2)
// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
// CHECK-DXIL: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
// CHECK-SPIRV: call spir_func void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
[shader("compute")]
[numthreads(8,8,1)]
void test(uint3 Idx : SV_GroupThreadID) {}
10 changes: 4 additions & 6 deletions clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -hlsl-entry CSMain -x hlsl -finclude-default-header -ast-dump -o - %s | FileCheck %s
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s

[numthreads(8,8,1)]
// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
// CHECK-NEXT: HLSLSV_GroupIndexAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
// CHECK-NEXT: HLSLSV_GroupIDAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 'uint'
// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
}
30 changes: 30 additions & 0 deletions clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,33 @@ struct ST2_GID {
static uint GID : SV_GroupID;
uint s_gid : SV_GroupID;
};

[numthreads(8,8,1)]
// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
void CSMain_GThreadID(float ID : SV_GroupThreadID) {
}

[numthreads(8,8,1)]
// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
void CSMain2_GThreadID(ST GID : SV_GroupThreadID) {

}

void foo_GThreadID() {
// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
uint GThreadIS : SV_GroupThreadID;
}

struct ST2_GThreadID {
// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
static uint GThreadID : SV_GroupThreadID;
uint s_gthreadid : SV_GroupThreadID;
};


[shader("vertex")]
// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'vertex' shaders, requires compute}}
// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'vertex' shaders, requires compute}}
// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'vertex' shaders, requires compute}}
// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'vertex' shaders, requires compute}}
void vs_main(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {}
25 changes: 25 additions & 0 deletions clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) {
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}

[numthreads(8,8,1)]
void CSMain_GThreadID(uint ID : SV_GroupThreadID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint'
// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
}
[numthreads(8,8,1)]
void CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2'
// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
}
[numthreads(8,8,1)]
void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3'
// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
}
[numthreads(8,8,1)]
void CSMain3_GThreadID(uint3 : SV_GroupThreadID) {
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3'
// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ let TargetPrefix = "spv" in {

// The following intrinsic(s) are mirrored from IntrinsicsDirectX.td for HLSL support.
def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
Expand Down
47 changes: 30 additions & 17 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,6 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSaturate(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, unsigned Opcode) const;

Expand Down Expand Up @@ -310,6 +307,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
void extractSubvector(Register &ResVReg, const SPIRVType *ResType,
Register &ReadReg, MachineInstr &InsertionPoint) const;
bool BuildCOPY(Register DestReg, Register SrcReg, MachineInstr &I) const;
bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue,
Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
};

} // end anonymous namespace
Expand Down Expand Up @@ -2825,7 +2825,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return BuildCOPY(ResVReg, I.getOperand(2).getReg(), I);
break;
case Intrinsic::spv_thread_id:
return selectSpvThreadId(ResVReg, ResType, I);
// The HLSL SV_DispatchThreadID semantic is lowered to the
// llvm.spv.thread.id intrinsic in LLVM IR for the SPIR-V backend.
//
// In the SPIR-V backend, llvm.spv.thread.id is translated to a
// `GlobalInvocationId` builtin variable.
return loadVec3BuiltinInputID(SPIRV::BuiltIn::GlobalInvocationId, ResVReg,
ResType, I);
case Intrinsic::spv_thread_id_in_group:
// The HLSL SV_GroupThreadID semantic is lowered to the
// llvm.spv.thread.id.in.group intrinsic in LLVM IR for the SPIR-V backend.
//
// In the SPIR-V backend, llvm.spv.thread.id.in.group is translated to a
// `LocalInvocationId` builtin variable.
return loadVec3BuiltinInputID(SPIRV::BuiltIn::LocalInvocationId, ResVReg,
ResType, I);
case Intrinsic::spv_fdot:
return selectFloatDot(ResVReg, ResType, I);
case Intrinsic::spv_udot:
Expand Down Expand Up @@ -3525,30 +3539,29 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
// DX intrinsic: @llvm.dx.thread.id(i32)
// ID Name Description
// 93 ThreadId reads the thread ID

// Generate the instructions to load a 3-element vector builtin input
// ID/index variable, such as GlobalInvocationId or LocalInvocationId, and
// extract the requested component from it.
bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
const SPIRVType *ResType, MachineInstr &I) const {
MachineIRBuilder MIRBuilder(I);
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
const SPIRVType *Vec3Ty =
GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);

// Create new register for GlobalInvocationID builtin variable.
// Create new register for the input ID builtin variable.
Register NewRegister =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64));
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());

// Build GlobalInvocationID global variable with the necessary decorations.
// Build global variable with the necessary decorations for the input ID
// builtin variable.
Register Variable = GR.buildGlobalVariable(
NewRegister, PtrType,
getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
SPIRV::StorageClass::Input, nullptr, true, true,
SPIRV::LinkageType::Import, MIRBuilder, false);

Expand All @@ -3565,12 +3578,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
.addUse(GR.getSPIRVTypeID(Vec3Ty))
.addUse(Variable);

// Get Thread ID index. Expecting operand is a constant immediate value,
// Get the input ID index. Expecting operand is a constant immediate value,
// wrapped in a type assignment.
assert(I.getOperand(2).isReg());
const uint32_t ThreadId = foldImm(I.getOperand(2), MRI);

// Extract the thread ID from the loaded vector value.
// Extract the input ID from the loaded vector value.
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(ResVReg)
Expand Down
Loading
Loading