Skip to content

[HLSL] Implement SV_GroupThreadId semantic #117781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 10, 2024

Conversation

lizhengxing
Copy link
Contributor

@lizhengxing lizhengxing commented Nov 26, 2024

Support HLSL SV_GroupThreadId attribute.

For directx target, translate it into dx.thread.id.in.group in clang codeGen and lower dx.thread.id.in.group to dx.op.threadIdInGroup in LLVM DirectX backend.

For spir-v target, translate it into spv.thread.id.in.group in clang codeGen and lower spv.thread.id.in.group to a LocalInvocationId builtin variable in LLVM SPIR-V backend.

Fixes: #70122

@lizhengxing lizhengxing marked this pull request as ready for review November 26, 2024 20:27
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels Nov 26, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Zhengxing li (lizhengxing)

Changes

Support SV_GroupThreadId attribute.
Translate it into dx.thread.id.in.group in clang codeGen.

Fixes: #70122


Full diff: https://github.com/llvm/llvm-project/pull/117781.diff

11 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+7)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+11)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+1)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.cpp (+5)
  • (modified) clang/lib/Parse/ParseHLSL.cpp (+1)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+10)
  • (added) clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl (+32)
  • (modified) clang/test/SemaHLSL/Semantics/entry_parameter.hlsl (+8-5)
  • (modified) clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl (+22)
  • (modified) clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl (+25)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index b055cbd769bb50..9c8e27c0f34e93 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4621,6 +4621,13 @@ def HLSLNumThreads: InheritableAttr {
   let Documentation = [NumThreadsDocs];
 }
 
+def HLSLSV_GroupThreadID: HLSLAnnotationAttr {
+  let Spellings = [HLSLAnnotation<"SV_GroupThreadID">];
+  let Subjects = SubjectList<[ParmVar, Field]>;
+  let LangOpts = [HLSL];
+  let Documentation = [HLSLSV_GroupThreadIDDocs];
+}
+
 def HLSLSV_GroupID: HLSLAnnotationAttr {
   let Spellings = [HLSLAnnotation<"SV_GroupID">];
   let Subjects = SubjectList<[ParmVar, Field]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index aafd4449e47004..88bf9a020586ea 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7934,6 +7934,17 @@ randomized.
   }];
 }
 
+def HLSLSV_GroupThreadIDDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``SV_GroupThreadID`` semantic, when applied to an input parameter, specifies which
+individual thread within a thread group is executing in. This attribute is
+only supported in compute shaders.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid
+  }];
+}
+
 def HLSLSV_GroupIDDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ee685d95c96154..f4cd11f423a84a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
+  void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
   void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
   void handleShaderAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 2c293523fca8ca..19db7faddaeac0 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
         CGM.getIntrinsic(getThreadIdIntrinsic());
     return buildVectorInput(B, ThreadIDIntrinsic, Ty);
   }
+  if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
+    llvm::Function *GroupThreadIDIntrinsic =
+        CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group);
+    return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
+  }
   if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
     llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
     return buildVectorInput(B, GroupIDIntrinsic, Ty);
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 4de342b63ed802..443bf2b9ec626a 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
   case ParsedAttr::UnknownAttribute:
     Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
     return;
+  case ParsedAttr::AT_HLSLSV_GroupThreadID:
   case ParsedAttr::AT_HLSLSV_GroupID:
   case ParsedAttr::AT_HLSLSV_GroupIndex:
   case ParsedAttr::AT_HLSLSV_DispatchThreadID:
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 53cc8cb6afd7dc..47e946c3ee64bc 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7103,6 +7103,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLWaveSize:
     S.HLSL().handleWaveSizeAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLSV_GroupThreadID:
+    S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupID:
     S.HLSL().handleSV_GroupIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 8b2f24a8e4be0a..7f3c6cb566bcbf 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
   switch (AnnotationAttr->getKind()) {
   case attr::HLSLSV_DispatchThreadID:
   case attr::HLSLSV_GroupIndex:
+  case attr::HLSLSV_GroupThreadID:
   case attr::HLSLSV_GroupID:
     if (ST == llvm::Triple::Compute)
       return;
@@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
 }
 
+void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
+  auto *VD = cast<ValueDecl>(D);
+  if (!diagnoseInputIDType(VD->getType(), AL))
+    return;
+
+  D->addAttr(::new (getASTContext())
+                 HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
+}
+
 void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
   auto *VD = cast<ValueDecl>(D);
   if (!diagnoseInputIDType(VD->getType(), AL))
diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
new file mode 100644
index 00000000000000..3533331c6f091c
--- /dev/null
+++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
+
+// Make sure SV_GroupThreadID translated into dx.thread.id.in.group.
+
+// CHECK:  define void @foo()
+// CHECK:  %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void foo(uint Idx : SV_GroupThreadID) {}
+
+// CHECK:  define void @bar()
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK:  %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK:  call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void bar(uint2 Idx : SV_GroupThreadID) {}
+
+// CHECK:  define void @test()
+// CHECK:  %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0)
+// CHECK:  %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0
+// CHECK:  %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1)
+// CHECK:  %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
+// CHECK:  %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2)
+// CHECK:  %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2
+// CHECK:  call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]])
+[shader("compute")]
+[numthreads(8,8,1)]
+void test(uint3 Idx : SV_GroupThreadID) {}
diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
index 13c07038d2e4a4..71d32cd13832e1 100644
--- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -2,15 +2,18 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header  -verify -o - %s
 
 [numthreads(8,8,1)]
-// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
-// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
-void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
-// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
+// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
+// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}}
+void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
 // CHECK-NEXT: HLSLSV_GroupIndexAttr
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
 // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
 // CHECK-NEXT: HLSLSV_GroupIDAttr
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 'uint'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
 }
diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
index 4e1f88aa2294b5..a24112c8e1bb8f 100644
--- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl
@@ -49,3 +49,25 @@ struct ST2_GID {
     static uint GID : SV_GroupID;
     uint s_gid : SV_GroupID;
 };
+
+[numthreads(8,8,1)]
+// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain_GThreadID(float ID : SV_GroupThreadID) {
+}
+
+[numthreads(8,8,1)]
+// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
+void CSMain2_GThreadID(ST GID : SV_GroupThreadID) {
+
+}
+
+void foo_GThreadID() {
+// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
+  uint GThreadIS : SV_GroupThreadID;
+}
+
+struct ST2_GThreadID {
+// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}}
+    static uint GThreadID : SV_GroupThreadID;
+    uint s_gthreadid : SV_GroupThreadID;
+};
diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
index 10a5e5dabac87b..6781f9241df240 100644
--- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
+++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl
@@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) {
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
 // CHECK-NEXT: HLSLSV_GroupIDAttr
 }
+
+[numthreads(8,8,1)]
+void CSMain_GThreadID(uint ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}
+[numthreads(8,8,1)]
+void CSMain3_GThreadID(uint3 : SV_GroupThreadID) {
+// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3'
+// CHECK-NEXT: HLSLSV_GroupThreadIDAttr
+}

Copy link
Contributor

@inbelic inbelic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to consider how spirv handles this? Otherwise, LGTM

@lizhengxing
Copy link
Contributor Author

lizhengxing commented Dec 4, 2024

Do we need to consider how spirv handles this? Otherwise, LGTM

This PR is similar to the one for SV_GroupID (#70120). We don't consider the spirv part for both PRs.

@llvm-beanz
Copy link
Collaborator

llvm-beanz commented Dec 4, 2024

This PR is similar to the one for SV_GroupID (#117781). We don't consider the spirv part for both PRs.

I believe that was a mistake, and I don't believe #117781 had sufficient reviews to have merged. It's fine, but I've filed #118700 to complete the outstanding work.

@lizhengxing lizhengxing force-pushed the hlsl-sv-groupthreadid branch from 9d5ffe0 to 28f8234 Compare December 6, 2024 17:41
@lizhengxing lizhengxing changed the title [HLSL] Implement SV_GroupThreadId semantic [HLSL][SPIRV] Implement SV_GroupThreadId semantic Dec 6, 2024
@lizhengxing lizhengxing changed the title [HLSL][SPIRV] Implement SV_GroupThreadId semantic [HLSL] Implement SV_GroupThreadId semantic Dec 6, 2024
Copy link
Contributor

@tex3d tex3d left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's good, with minor nit about unnecessary comments.

@lizhengxing
Copy link
Contributor Author

@sudonatalie
Hello Natalie,
Could you help review this PR when you're available since it will modify your changes (#82536), Thanks.

Support SV_GroupThreadId attribute.
Translate it into dx.thread.id.in.group in clang codeGen.

Fixes: llvm#70122
The SV_GroupIndex, SV_DispatchThreadID, SV_GroupID and SV_GroupThreadID are actually legal for meash shader stage. It shouldn't test them with mesh shader.

This commit tests them with vertex shader and move the test into invalid_entry_parameter.hlsl which's a better place for it.
The HLSL SV_GroupThreadID semantic attribute is lowered into @llvm.spv.thread.id.in.group intrinsic in LLVM IR for SPIR-V target.

In the SPIR-V backend, this is now correctly translated to a `LocalInvocationId` builtin variable.

Fixes llvm#70122
@lizhengxing lizhengxing force-pushed the hlsl-sv-groupthreadid branch from 1a3355a to c05f268 Compare December 10, 2024 19:29
@spall spall merged commit 951a284 into llvm:main Dec 10, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:SPIR-V clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[HLSL] implement SV_GroupThreadId semantic
6 participants