Skip to content

Commit 94189b4

Browse files
committed
[HLSL] Fix MSFT Attribute parsing, add numthreads
HLSL uses Microsoft-style attributes `[attr]`, which clang mostly ignores. For HLSL we need to handle known Microsoft attributes, and to maintain C/C++ as-is we ignore unknown attributes. To utilize this new code path, this change adds the HLSL `numthreads` attribute. Reviewed By: rnk Differential Revision: https://reviews.llvm.org/D122627
1 parent fe8b223 commit 94189b4

File tree

9 files changed

+170
-3
lines changed

9 files changed

+170
-3
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def ObjCAutoRefCount : LangOpt<"ObjCAutoRefCount">;
336336
def ObjCNonFragileRuntime
337337
: LangOpt<"", "LangOpts.ObjCRuntime.allowsClassStubs()">;
338338

339+
def HLSL : LangOpt<"HLSL">;
340+
339341
// Language option for CMSE extensions
340342
def Cmse : LangOpt<"Cmse">;
341343

@@ -3937,3 +3939,11 @@ def Error : InheritableAttr {
39373939
let Subjects = SubjectList<[Function], ErrorDiag>;
39383940
let Documentation = [ErrorAttrDocs];
39393941
}
3942+
3943+
def HLSLNumThreads: InheritableAttr {
3944+
let Spellings = [Microsoft<"numthreads">];
3945+
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
3946+
let Subjects = SubjectList<[Function]>;
3947+
let LangOpts = [HLSL];
3948+
let Documentation = [NumThreadsDocs];
3949+
}

clang/include/clang/Basic/AttrDocs.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6368,3 +6368,14 @@ flag.
63686368
.. _Return-Oriented Programming: https://en.wikipedia.org/wiki/Return-oriented_programming
63696369
}];
63706370
}
6371+
6372+
def NumThreadsDocs : Documentation {
6373+
let Category = DocCatFunction;
6374+
let Content = [{
6375+
The ``numthreads`` attribute applies to HLSL shaders where explcit thread counts
6376+
are required. The ``X``, ``Y``, and ``Z`` values provided to the attribute
6377+
dictate the thread id. Total number of threads executed is ``X * Y * Z``.
6378+
6379+
The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-numthreads
6380+
}];
6381+
}

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11564,4 +11564,12 @@ def err_std_source_location_impl_not_found : Error<
1156411564
"'std::source_location::__impl' was not found; it must be defined before '__builtin_source_location' is called">;
1156511565
def err_std_source_location_impl_malformed : Error<
1156611566
"'std::source_location::__impl' must be standard-layout and have only two 'const char *' fields '_M_file_name' and '_M_function_name', and two integral fields '_M_line' and '_M_column'">;
11567+
11568+
// HLSL Diagnostics
11569+
def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %select{Pixel|Vertex|Geometry|Hull|Domain|Compute|Library|RayGeneration|Intersection|AnyHit|ClosestHit|Miss|Callable|Mesh|Amplification|Invalid}1 shaders, requires %2">;
11570+
11571+
def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
11572+
def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
11573+
1156711574
} // end of sema component.
11575+

clang/include/clang/Parse/Parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2783,7 +2783,8 @@ class Parser : public CodeCompletionHandler {
27832783
const IdentifierInfo *EnclosingScope = nullptr);
27842784

27852785
void MaybeParseMicrosoftAttributes(ParsedAttributes &Attrs) {
2786-
if (getLangOpts().MicrosoftExt && Tok.is(tok::l_square)) {
2786+
if ((getLangOpts().MicrosoftExt || getLangOpts().HLSL) &&
2787+
Tok.is(tok::l_square)) {
27872788
ParsedAttributes AttrsWithRange(AttrFactory);
27882789
ParseMicrosoftAttributes(AttrsWithRange);
27892790
Attrs.takeAllFrom(AttrsWithRange);

clang/lib/Parse/ParseDeclCXX.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4302,10 +4302,19 @@ bool Parser::ParseCXX11AttributeArgs(IdentifierInfo *AttrName,
43024302
ParsedAttr::Syntax Syntax =
43034303
LO.CPlusPlus ? ParsedAttr::AS_CXX11 : ParsedAttr::AS_C2x;
43044304

4305+
// Try parsing microsoft attributes
4306+
if (getLangOpts().MicrosoftExt || getLangOpts().HLSL) {
4307+
if (hasAttribute(AttrSyntax::Microsoft, ScopeName, AttrName,
4308+
getTargetInfo(), getLangOpts()))
4309+
Syntax = ParsedAttr::AS_Microsoft;
4310+
}
4311+
43054312
// If the attribute isn't known, we will not attempt to parse any
43064313
// arguments.
4307-
if (!hasAttribute(LO.CPlusPlus ? AttrSyntax::CXX : AttrSyntax::C, ScopeName,
4314+
if (Syntax != ParsedAttr::AS_Microsoft &&
4315+
!hasAttribute(LO.CPlusPlus ? AttrSyntax::CXX : AttrSyntax::C, ScopeName,
43084316
AttrName, getTargetInfo(), getLangOpts())) {
4317+
if (getLangOpts().MicrosoftExt || getLangOpts().HLSL) {}
43094318
// Eat the left paren, then skip to the ending right paren.
43104319
ConsumeParen();
43114320
SkipUntil(tok::r_paren);
@@ -4688,8 +4697,17 @@ void Parser::ParseMicrosoftAttributes(ParsedAttributes &Attrs) {
46884697
break;
46894698
if (Tok.getIdentifierInfo()->getName() == "uuid")
46904699
ParseMicrosoftUuidAttributeArgs(Attrs);
4691-
else
4700+
else {
4701+
IdentifierInfo *II = Tok.getIdentifierInfo();
4702+
SourceLocation NameLoc = Tok.getLocation();
46924703
ConsumeToken();
4704+
if (Tok.is(tok::l_paren)) {
4705+
CachedTokens OpenMPTokens;
4706+
ParseCXX11AttributeArgs(II, NameLoc, Attrs, &EndLoc, nullptr,
4707+
SourceLocation(), OpenMPTokens);
4708+
ReplayOpenMPAttributeTokens(OpenMPTokens);
4709+
} // FIXME: handle attributes that don't have arguments
4710+
}
46934711
}
46944712

46954713
T.consumeClose();

clang/lib/Sema/SemaDecl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11323,6 +11323,11 @@ void Sema::CheckMain(FunctionDecl* FD, const DeclSpec& DS) {
1132311323
return;
1132411324
}
1132511325

11326+
// Functions named main in hlsl are default entries, but don't have specific
11327+
// signatures they are required to conform to.
11328+
if (getLangOpts().HLSL)
11329+
return;
11330+
1132611331
QualType T = FD->getType();
1132711332
assert(T->isFunctionType() && "function decl is not of function type");
1132811333
const FunctionType* FT = T->castAs<FunctionType>();

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "clang/AST/Type.h"
2525
#include "clang/Basic/CharInfo.h"
2626
#include "clang/Basic/DarwinSDKInfo.h"
27+
#include "clang/Basic/LangOptions.h"
2728
#include "clang/Basic/SourceLocation.h"
2829
#include "clang/Basic/SourceManager.h"
2930
#include "clang/Basic/TargetBuiltins.h"
@@ -6836,6 +6837,64 @@ static void handleUuidAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
68366837
D->addAttr(UA);
68376838
}
68386839

6840+
static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
6841+
using llvm::Triple;
6842+
Triple Target = S.Context.getTargetInfo().getTriple();
6843+
if (!llvm::is_contained({Triple::Compute, Triple::Mesh, Triple::Amplification,
6844+
Triple::Library},
6845+
Target.getEnvironment())) {
6846+
uint32_t Pipeline =
6847+
(uint32_t)S.Context.getTargetInfo().getTriple().getEnvironment() -
6848+
(uint32_t)llvm::Triple::Pixel;
6849+
S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
6850+
<< AL << Pipeline << "Compute, Amplification, Mesh or Library";
6851+
return;
6852+
}
6853+
6854+
llvm::VersionTuple SMVersion = Target.getOSVersion();
6855+
uint32_t ZMax = 1024;
6856+
uint32_t ThreadMax = 1024;
6857+
if (SMVersion.getMajor() <= 4) {
6858+
ZMax = 1;
6859+
ThreadMax = 768;
6860+
} else if (SMVersion.getMajor() == 5) {
6861+
ZMax = 64;
6862+
ThreadMax = 1024;
6863+
}
6864+
6865+
uint32_t X;
6866+
if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(0), X))
6867+
return;
6868+
if (X > 1024) {
6869+
S.Diag(AL.getArgAsExpr(0)->getExprLoc(),
6870+
diag::err_hlsl_numthreads_argument_oor) << 0 << 1024;
6871+
return;
6872+
}
6873+
uint32_t Y;
6874+
if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(1), Y))
6875+
return;
6876+
if (Y > 1024) {
6877+
S.Diag(AL.getArgAsExpr(1)->getExprLoc(),
6878+
diag::err_hlsl_numthreads_argument_oor) << 1 << 1024;
6879+
return;
6880+
}
6881+
uint32_t Z;
6882+
if (!checkUInt32Argument(S, AL, AL.getArgAsExpr(2), Z))
6883+
return;
6884+
if (Z > ZMax) {
6885+
S.Diag(AL.getArgAsExpr(2)->getExprLoc(),
6886+
diag::err_hlsl_numthreads_argument_oor) << 2 << ZMax;
6887+
return;
6888+
}
6889+
6890+
if (X * Y * Z > ThreadMax) {
6891+
S.Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;
6892+
return;
6893+
}
6894+
6895+
D->addAttr(::new (S.Context) HLSLNumThreadsAttr(S.Context, AL, X, Y, Z));
6896+
}
6897+
68396898
static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
68406899
if (!S.LangOpts.CPlusPlus) {
68416900
S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
@@ -8697,6 +8756,11 @@ static void ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D,
86978756
case ParsedAttr::AT_Thread:
86988757
handleDeclspecThreadAttr(S, D, AL);
86998758
break;
8759+
8760+
// HLSL attributes:
8761+
case ParsedAttr::AT_HLSLNumThreads:
8762+
handleHLSLNumThreadsAttr(S, D, AL);
8763+
break;
87008764

87018765
case ParsedAttr::AT_AbiTag:
87028766
handleAbiTagAttr(S, D, AL);

clang/test/SemaHLSL/lit.local.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
config.suffixes = ['.hlsl']

clang/test/SemaHLSL/num_threads.hlsl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
2+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s
3+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s
4+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s
5+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-pixel -x hlsl -ast-dump -o - %s -verify
6+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-vertex -x hlsl -ast-dump -o - %s -verify
7+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-hull -x hlsl -ast-dump -o - %s -verify
8+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-domain -x hlsl -ast-dump -o - %s -verify
9+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
10+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel5.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
11+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel4.0-compute -x hlsl -ast-dump -o - %s -DFAIL -verify
12+
13+
#if __SHADER_TARGET_STAGE == __SHADER_STAGE_COMPUTE || __SHADER_TARGET_STAGE == __SHADER_STAGE_MESH || __SHADER_TARGET_STAGE == __SHADER_STAGE_AMPLIFICATION || __SHADER_TARGET_STAGE == __SHADER_STAGE_LIBRARY
14+
#ifdef FAIL
15+
#if __SHADER_TARGET_MAJOR == 6
16+
// expected-error@+1 {{'numthreads' attribute requires an integer constant}}
17+
[numthreads("1",2,3)]
18+
// expected-error@+1 {{argument 'X' to numthreads attribute cannot exceed 1024}}
19+
[numthreads(-1,2,3)]
20+
// expected-error@+1 {{argument 'Y' to numthreads attribute cannot exceed 1024}}
21+
[numthreads(1,-2,3)]
22+
// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 1024}}
23+
[numthreads(1,2,-3)]
24+
// expected-error@+1 {{total number of threads cannot exceed 1024}}
25+
[numthreads(1024,1024,1024)]
26+
#elif __SHADER_TARGET_MAJOR == 5
27+
// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 64}}
28+
[numthreads(1,2,68)]
29+
#else
30+
// expected-error@+1 {{argument 'Z' to numthreads attribute cannot exceed 1}}
31+
[numthreads(1,2,2)]
32+
// expected-error@+1 {{total number of threads cannot exceed 768}}
33+
[numthreads(1024,1,1)]
34+
#endif
35+
#endif
36+
// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:{{[0-9]+}}:2, col:18> 1 2 1
37+
[numthreads(1,2,1)]
38+
int entry() {
39+
return 1;
40+
}
41+
#else
42+
// expected-error-re@+1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires Compute, Amplification, Mesh or Library}}
43+
[numthreads(1,1,1)]
44+
int main() {
45+
return 1;
46+
}
47+
#endif
48+
49+

0 commit comments

Comments
 (0)