-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[HLSL] Implement SV_GroupThreadId semantic #117781
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-clang Author: Zhengxing li (lizhengxing) ChangesSupport SV_GroupThreadId attribute. Fixes: #70122 Full diff: https://github.com/llvm/llvm-project/pull/117781.diff 11 Files Affected:
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index b055cbd769bb50..9c8e27c0f34e93 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4621,6 +4621,13 @@ def HLSLNumThreads: InheritableAttr {
let Documentation = [NumThreadsDocs];
}
+def HLSLSV_GroupThreadID: HLSLAnnotationAttr {
+ let Spellings = [HLSLAnnotation<"SV_GroupThreadID">];
+ let Subjects = SubjectList<[ParmVar, Field]>;
+ let LangOpts = [HLSL];
+ let Documentation = [HLSLSV_GroupThreadIDDocs];
+}
+
def HLSLSV_GroupID: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"SV_GroupID">];
let Subjects = SubjectList<[ParmVar, Field]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index aafd4449e47004..88bf9a020586ea 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7934,6 +7934,17 @@ randomized.
}];
}
+def HLSLSV_GroupThreadIDDocs : Documentation {
+ let Category = DocCatFunction;
+ let Content = [{
+The ``SV_GroupThreadID`` semantic, when applied to an input parameter, specifies which
+individual thread within a thread group is executing in. This attribute is
+only supported in compute shaders.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid
+ }];
+}
+
def HLSLSV_GroupIDDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ee685d95c96154..f4cd11f423a84a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
+ void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 2c293523fca8ca..19db7faddaeac0 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
CGM.getIntrinsic(getThreadIdIntrinsic());
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
}
+ if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
+ llvm::Function *GroupThreadIDIntrinsic =
+ CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group);
+ return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
+ }
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
return buildVectorInput(B, GroupIDIntrinsic, Ty);
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 4de342b63ed802..443bf2b9ec626a 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
case ParsedAttr::UnknownAttribute:
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
return;
+ case ParsedAttr::AT_HLSLSV_GroupThreadID:
case ParsedAttr::AT_HLSLSV_GroupID:
case ParsedAttr::AT_HLSLSV_GroupIndex:
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 53cc8cb6afd7dc..47e946c3ee64bc 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7103,6 +7103,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLWaveSize:
S.HLSL().handleWaveSizeAttr(D, AL);
break;
+ case ParsedAttr::AT_HLSLSV_GroupThreadID:
+ S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
+ break;
case ParsedAttr::AT_HLSLSV_GroupID:
S.HLSL().handleSV_GroupIDAttr(D, AL);
break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 8b2f24a8e4be0a..7f3c6cb566bcbf 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
+ case attr::HLSLSV_GroupThreadID:
case attr::HLSLSV_GroupID:
if (ST == llvm::Triple::Compute)
return;
@@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
}
+void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
+ auto *VD = cast<ValueDecl>(D);
+ if (!diagnoseInputIDType(VD->getType(), AL))
+ return;
+
+ D->addAttr(::new (getASTContext())
+ HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
+}
+
void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!diagnoseInputIDType(VD->getType(), AL))
diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
new file mode 100644
index 00000000000000..3533331c6f091c
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
+
+// Make sure SV_GroupThreadID translated into dx.thread.id.in.group.
+
+// CHECK: define void @foo()
+// CHECK: %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void foo(uint Idx : SV_GroupThreadID) {}
+
+// CHECK: define void @bar()
+// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void bar(uint2 Idx : SV_GroupThreadID) {}
+
+// CHECK: define void @test()
+// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2)
+// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
+// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void test(uint3 Idx : SV_GroupThreadID) {}
diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
index 13c07038d2e4a4..71d32cd13832e1 100644
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -2,15 +2,18 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s
[numthreads(8,8,1)]
-// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
-void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
-// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
+// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}}
+void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)'
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
// CHECK-NEXT: HLSLSV_GroupIndexAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
// CHECK-NEXT: HLSLSV_GroupIDAttr
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 'uint'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
}
diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
index 4e1f88aa2294b5..a24112c8e1bb8f 100644
--- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
@@ -49,3 +49,25 @@ struct ST2_GID {
static uint GID : SV_GroupID;
uint s_gid : SV_GroupID;
};
+
+[numthreads(8,8,1)]
+// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain_GThreadID(float ID : SV_GroupThreadID) {
+}
+
+[numthreads(8,8,1)]
+// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain2_GThreadID(ST GID : SV_GroupThreadID) {
+
+}
+
+void foo_GThreadID() {
+// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
+ uint GThreadIS : SV_GroupThreadID;
+}
+
+struct ST2_GThreadID {
+// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
+ static uint GThreadID : SV_GroupThreadID;
+ uint s_gthreadid : SV_GroupThreadID;
+};
diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
index 10a5e5dabac87b..6781f9241df240 100644
--- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
@@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) {
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
// CHECK-NEXT: HLSLSV_GroupIDAttr
}
+
+[numthreads(8,8,1)]
+void CSMain_GThreadID(uint ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain3_GThreadID(uint3 : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to consider how spirv handles this? Otherwise, LGTM
This PR is similar to the one for SV_GroupID (#70120). We don't consider the spirv part for both PRs. |
9d5ffe0
to
28f8234
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's good, with minor nit about unnecessary comments.
@sudonatalie |
Support SV_GroupThreadId attribute. Translate it into dx.thread.id.in.group in clang codeGen. Fixes: llvm#70122
The SV_GroupIndex, SV_DispatchThreadID, SV_GroupID and SV_GroupThreadID are actually legal for meash shader stage. It shouldn't test them with mesh shader. This commit tests them with vertex shader and move the test into invalid_entry_parameter.hlsl which's a better place for it.
The HLSL SV_GroupThreadID semantic attribute is lowered into @llvm.spv.thread.id.in.group intrinsic in LLVM IR for SPIR-V target. In the SPIR-V backend, this is now correctly translated to a `LocalInvocationId` builtin variable. Fixes llvm#70122
1a3355a
to
c05f268
Compare
Support HLSL SV_GroupThreadId attribute.
For
directx
target, translate it intodx.thread.id.in.group
in clang codeGen and lowerdx.thread.id.in.group
todx.op.threadIdInGroup
in LLVM DirectX backend.For
spir-v
target, translate it intospv.thread.id.in.group
in clang codeGen and lowerspv.thread.id.in.group
to aLocalInvocationId
builtin variable in LLVM SPIR-V backend.Fixes: #70122