Skip to content

Commit e41579a

Browse files
authored
[HLSL] AST support for WaveSize attribute. (#101240)
First step for support WaveSize attribute in https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html A new attribute HLSLWaveSizeAttr was supported in the AST. Implement both the wave size and the wave size range, rather than separately which might require more work. For #70118
1 parent a3ea90f commit e41579a

File tree

13 files changed

+373
-1
lines changed

13 files changed

+373
-1
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4651,6 +4651,22 @@ def HLSLParamModifier : TypeAttr {
46514651
let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
46524652
}
46534653

4654+
def HLSLWaveSize: InheritableAttr {
4655+
let Spellings = [Microsoft<"WaveSize">];
4656+
let Args = [IntArgument<"Min">, DefaultIntArgument<"Max", 0>, DefaultIntArgument<"Preferred", 0>];
4657+
let Subjects = SubjectList<[HLSLEntry]>;
4658+
let LangOpts = [HLSL];
4659+
let AdditionalMembers = [{
4660+
private:
4661+
int SpelledArgsCount = 0;
4662+
4663+
public:
4664+
void setSpelledArgsCount(int C) { SpelledArgsCount = C; }
4665+
int getSpelledArgsCount() const { return SpelledArgsCount; }
4666+
}];
4667+
let Documentation = [WaveSizeDocs];
4668+
}
4669+
46544670
def RandomizeLayout : InheritableAttr {
46554671
let Spellings = [GCC<"randomize_layout">];
46564672
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7421,6 +7421,43 @@ flag.
74217421
}];
74227422
}
74237423

7424+
def WaveSizeDocs : Documentation {
7425+
let Category = DocCatFunction;
7426+
let Content = [{
7427+
The ``WaveSize`` attribute specify a wave size on a shader entry point in order
7428+
to indicate either that a shader depends on or strongly prefers a specific wave
7429+
size.
7430+
There're 2 versions of the attribute: ``WaveSize`` and ``RangedWaveSize``.
7431+
The syntax for ``WaveSize`` is:
7432+
7433+
.. code-block:: text
7434+
7435+
``[WaveSize(<numLanes>)]``
7436+
7437+
The allowed wave sizes that an HLSL shader may specify are the powers of 2
7438+
between 4 and 128, inclusive.
7439+
In other words, the set: [4, 8, 16, 32, 64, 128].
7440+
7441+
The syntax for ``RangedWaveSize`` is:
7442+
7443+
.. code-block:: text
7444+
7445+
``[WaveSize(<minWaveSize>, <maxWaveSize>, [prefWaveSize])]``
7446+
7447+
Where minWaveSize is the minimum wave size supported by the shader representing
7448+
the beginning of the allowed range, maxWaveSize is the maximum wave size
7449+
supported by the shader representing the end of the allowed range, and
7450+
prefWaveSize is the optional preferred wave size representing the size expected
7451+
to be the most optimal for this shader.
7452+
7453+
``WaveSize`` is available for HLSL shader model 6.6 and later.
7454+
``RangedWaveSize`` available for HLSL shader model 6.8 and later.
7455+
7456+
The full documentation is available here: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
7457+
and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
7458+
}];
7459+
}
7460+
74247461
def NumThreadsDocs : Documentation {
74257462
let Category = DocCatFunction;
74267463
let Content = [{

clang/include/clang/Basic/DiagnosticGroups.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,9 @@ def HLSLAvailability : DiagGroup<"hlsl-availability">;
15501550
// Warnings for legacy binding behavior
15511551
def LegacyConstantRegisterBinding : DiagGroup<"legacy-constant-register-binding">;
15521552

1553+
// Warning for HLSL Attributes on Statement.
1554+
def HLSLAttributeStatement : DiagGroup<"attribute-statement">;
1555+
15531556
// Warnings and notes related to const_var_decl_type attribute checks
15541557
def ReadOnlyPlacementChecks : DiagGroup<"read-only-types">;
15551558

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12384,6 +12384,16 @@ def warn_hlsl_availability_unavailable :
1238412384
def err_hlsl_export_not_on_function : Error<
1238512385
"export declaration can only be used on functions">;
1238612386

12387+
def err_hlsl_attribute_in_wrong_shader_model: Error<
12388+
"attribute %0 requires shader model %1 or greater">;
12389+
12390+
def warn_attr_min_eq_max: Warning<
12391+
"%0 attribute minimum and maximum arguments are equal">,
12392+
InGroup<HLSLAttributeStatement>, DefaultError;
12393+
12394+
def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
12395+
"attribute %0 with %1 arguments requires shader model %2 or greater">;
12396+
1238712397
// Layout randomization diagnostics.
1238812398
def err_non_designated_init_used : Error<
1238912399
"a randomized struct can only be initialized with a designated initializer">;

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class SemaHLSL : public SemaBase {
3737
HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
3838
const AttributeCommonInfo &AL, int X,
3939
int Y, int Z);
40+
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
41+
int Min, int Max, int Preferred,
42+
int SpelledArgsCount);
4043
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
4144
llvm::Triple::EnvironmentType ShaderType);
4245
HLSLParamModifierAttr *
@@ -52,6 +55,7 @@ class SemaHLSL : public SemaBase {
5255
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
5356

5457
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
58+
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
5559
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
5660
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
5761
void handleShaderAttr(Decl *D, const ParsedAttr &AL);

clang/lib/Sema/SemaDecl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,6 +2863,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
28632863
else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
28642864
NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
28652865
NT->getZ());
2866+
else if (const auto *WS = dyn_cast<HLSLWaveSizeAttr>(Attr))
2867+
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
2868+
WS->getPreferred(),
2869+
WS->getSpelledArgsCount());
28662870
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
28672871
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
28682872
else if (isa<SuppressAttr>(Attr))

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6886,6 +6886,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
68866886
case ParsedAttr::AT_HLSLNumThreads:
68876887
S.HLSL().handleNumThreadsAttr(D, AL);
68886888
break;
6889+
case ParsedAttr::AT_HLSLWaveSize:
6890+
S.HLSL().handleWaveSizeAttr(D, AL);
6891+
break;
68896892
case ParsedAttr::AT_HLSLSV_GroupIndex:
68906893
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
68916894
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/ADT/StringExtras.h"
2323
#include "llvm/ADT/StringRef.h"
2424
#include "llvm/Support/Casting.h"
25+
#include "llvm/Support/DXILABI.h"
2526
#include "llvm/Support/ErrorHandling.h"
2627
#include "llvm/TargetParser/Triple.h"
2728
#include <iterator>
@@ -153,6 +154,25 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
153154
HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
154155
}
155156

157+
HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
158+
const AttributeCommonInfo &AL,
159+
int Min, int Max, int Preferred,
160+
int SpelledArgsCount) {
161+
if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
162+
if (WS->getMin() != Min || WS->getMax() != Max ||
163+
WS->getPreferred() != Preferred ||
164+
WS->getSpelledArgsCount() != SpelledArgsCount) {
165+
Diag(WS->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
166+
Diag(AL.getLoc(), diag::note_conflicting_attribute);
167+
}
168+
return nullptr;
169+
}
170+
HLSLWaveSizeAttr *Result = ::new (getASTContext())
171+
HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
172+
Result->setSpelledArgsCount(SpelledArgsCount);
173+
return Result;
174+
}
175+
156176
HLSLShaderAttr *
157177
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
158178
llvm::Triple::EnvironmentType ShaderType) {
@@ -224,7 +244,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
224244
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
225245
assert(ShaderAttr && "Entry point has no shader attribute");
226246
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
227-
247+
auto &TargetInfo = getASTContext().getTargetInfo();
248+
VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
228249
switch (ST) {
229250
case llvm::Triple::Pixel:
230251
case llvm::Triple::Vertex:
@@ -244,6 +265,13 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
244265
llvm::Triple::Mesh});
245266
FD->setInvalidDecl();
246267
}
268+
if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
269+
DiagnoseAttrStageMismatch(WS, ST,
270+
{llvm::Triple::Compute,
271+
llvm::Triple::Amplification,
272+
llvm::Triple::Mesh});
273+
FD->setInvalidDecl();
274+
}
247275
break;
248276

249277
case llvm::Triple::Compute:
@@ -254,6 +282,19 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
254282
<< llvm::Triple::getEnvironmentTypeName(ST);
255283
FD->setInvalidDecl();
256284
}
285+
if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
286+
if (Ver < VersionTuple(6, 6)) {
287+
Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
288+
<< WS << "6.6";
289+
FD->setInvalidDecl();
290+
} else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
291+
Diag(
292+
WS->getLocation(),
293+
diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
294+
<< WS << WS->getSpelledArgsCount() << "6.8";
295+
FD->setInvalidDecl();
296+
}
297+
}
257298
break;
258299
default:
259300
llvm_unreachable("Unhandled environment in triple");
@@ -357,6 +398,74 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
357398
D->addAttr(NewAttr);
358399
}
359400

401+
static bool isValidWaveSizeValue(unsigned Value) {
402+
return llvm::isPowerOf2_32(Value) && Value >= 4 && Value <= 128;
403+
}
404+
405+
void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
406+
// validate that the wavesize argument is a power of 2 between 4 and 128
407+
// inclusive
408+
unsigned SpelledArgsCount = AL.getNumArgs();
409+
if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
410+
return;
411+
412+
uint32_t Min;
413+
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
414+
return;
415+
416+
uint32_t Max = 0;
417+
if (SpelledArgsCount > 1 &&
418+
!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
419+
return;
420+
421+
uint32_t Preferred = 0;
422+
if (SpelledArgsCount > 2 &&
423+
!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
424+
return;
425+
426+
if (SpelledArgsCount > 2) {
427+
if (!isValidWaveSizeValue(Preferred)) {
428+
Diag(AL.getArgAsExpr(2)->getExprLoc(),
429+
diag::err_attribute_power_of_two_in_range)
430+
<< AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
431+
<< Preferred;
432+
return;
433+
}
434+
// Preferred not in range.
435+
if (Preferred < Min || Preferred > Max) {
436+
Diag(AL.getArgAsExpr(2)->getExprLoc(),
437+
diag::err_attribute_power_of_two_in_range)
438+
<< AL << Min << Max << Preferred;
439+
return;
440+
}
441+
} else if (SpelledArgsCount > 1) {
442+
if (!isValidWaveSizeValue(Max)) {
443+
Diag(AL.getArgAsExpr(1)->getExprLoc(),
444+
diag::err_attribute_power_of_two_in_range)
445+
<< AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
446+
return;
447+
}
448+
if (Max < Min) {
449+
Diag(AL.getLoc(), diag::err_attribute_argument_invalid) << AL << 1;
450+
return;
451+
} else if (Max == Min) {
452+
Diag(AL.getLoc(), diag::warn_attr_min_eq_max) << AL;
453+
}
454+
} else {
455+
if (!isValidWaveSizeValue(Min)) {
456+
Diag(AL.getArgAsExpr(0)->getExprLoc(),
457+
diag::err_attribute_power_of_two_in_range)
458+
<< AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
459+
return;
460+
}
461+
}
462+
463+
HLSLWaveSizeAttr *NewAttr =
464+
mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
465+
if (NewAttr)
466+
D->addAttr(NewAttr);
467+
}
468+
360469
static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
361470
if (!T->hasUnsignedIntegerRepresentation())
362471
return false;

clang/test/AST/HLSL/WaveSize.hlsl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl -ast-dump -o - %s | FileCheck %s
2+
3+
// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w0 'void ()'
4+
// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 128 0 0
5+
[numthreads(8,8,1)]
6+
[WaveSize(128)]
7+
void w0() {
8+
}
9+
10+
11+
12+
// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w1 'void ()'
13+
// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 64 0
14+
[numthreads(8,8,1)]
15+
[WaveSize(8, 64)]
16+
void w1() {
17+
}
18+
19+
20+
// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w2 'void ()'
21+
// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
22+
// Duplicate WaveSize attribute will be ignored.
23+
// CHECK-NOT:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
24+
[numthreads(8,8,1)]
25+
[WaveSize(8, 128, 64)]
26+
[WaveSize(8, 128, 64)]
27+
void w2() {
28+
}

0 commit comments

Comments
 (0)