-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[HLSL] AST support for WaveSize attribute. #101240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-backend-directx Author: Xiang Li (python3kgae) ChangesFirst step for support WaveSize attribute in A new attribute HLSLWaveSizeAttr was supported in the AST. Implement both the wave size and the wave size range, rather than separately which might require more work. For #70118 Patch is 24.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101240.diff 14 Files Affected:
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 46d0a66d59c37..8b2f8358aec28 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4625,6 +4625,22 @@ def HLSLParamModifier : TypeAttr {
let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
}
+def HLSLWaveSize: InheritableAttr {
+ let Spellings = [Microsoft<"WaveSize">];
+ let Args = [IntArgument<"Min">, DefaultIntArgument<"Max", 0>, DefaultIntArgument<"Preferred", 0>];
+ let Subjects = SubjectList<[HLSLEntry]>;
+ let LangOpts = [HLSL];
+ let AdditionalMembers = [{
+ private:
+ int SpelledArgsCount = 0;
+
+ public:
+ void setSpelledArgsCount(int C) { SpelledArgsCount = C; }
+ int getSpelledArgsCount() const { return SpelledArgsCount; }
+ }];
+ let Documentation = [WaveSizeDocs];
+}
+
def RandomizeLayout : InheritableAttr {
let Spellings = [GCC<"randomize_layout">];
let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 4b8d520d73893..e3c98912c81f4 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7322,6 +7322,43 @@ flag.
}];
}
+def WaveSizeDocs : Documentation {
+ let Category = DocCatFunction;
+ let Content = [{
+The ``WaveSize`` attribute specify a wave size on a shader entry point in order
+to indicate either that a shader depends on or strongly prefers a specific wave
+size.
+There're 2 versions of the attribute: ``WaveSize`` and ``RangedWaveSize``.
+The syntax for ``WaveSize`` is:
+
+.. code-block:: text
+
+ ``[WaveSize(<numLanes>)]``
+
+The allowed wave sizes that an HLSL shader may specify are the powers of 2
+between 4 and 128, inclusive.
+In other words, the set: [4, 8, 16, 32, 64, 128].
+
+The syntax for ``RangedWaveSize`` is:
+
+.. code-block:: text
+
+ ``[WaveSize(<minWaveSize>, <maxWaveSize>, [prefWaveSize])]``
+
+Where minWaveSize is the minimum wave size supported by the shader representing
+the beginning of the allowed range, maxWaveSize is the maximum wave size
+supported by the shader representing the end of the allowed range, and
+prefWaveSize is the optional preferred wave size representing the size expected
+to be the most optimal for this shader.
+
+``WaveSize`` is available for HLSL shader model 6.6 and later.
+``RangedWaveSize`` available for HLSL shader model 6.8 and later.
+
+The full documentation is available here: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
+and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
+ }];
+}
+
def NumThreadsDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticGroups.td b/clang/include/clang/Basic/DiagnosticGroups.td
index 19c3f1e043349..122b95e9f9a2e 100644
--- a/clang/include/clang/Basic/DiagnosticGroups.td
+++ b/clang/include/clang/Basic/DiagnosticGroups.td
@@ -1547,6 +1547,9 @@ def DXILValidation : DiagGroup<"dxil-validation">;
// Warning for HLSL API availability
def HLSLAvailability : DiagGroup<"hlsl-availability">;
+// Warning for HLSL Attributes on Statement.
+def HLSLAttributeStatement : DiagGroup<"attribute-statement">;
+
// Warnings and notes related to const_var_decl_type attribute checks
def ReadOnlyPlacementChecks : DiagGroup<"read-only-types">;
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 581434d33c5c9..9010812837d42 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12361,6 +12361,21 @@ def warn_hlsl_availability_unavailable :
def err_hlsl_export_not_on_function : Error<
"export declaration can only be used on functions">;
+def err_hlsl_attribute_in_wrong_shader_model: Error<
+ "attribute %0 requires shader model %1 or greater">;
+
+def err_hlsl_wavesize_size: Error<
+ "wavesize arguments must be between 4 and 128 and a power of 2">;
+def err_hlsl_wavesize_min_geq_max: Error<
+ "minimum wavesize value %0 must be less than maximum wavesize value %1">;
+def warn_hlsl_wavesize_min_eq_max: Warning<
+ "wave size range minimum and maximum are equal">,
+ InGroup<HLSLAttributeStatement>, DefaultError;
+def err_hlsl_wavesize_pref_size_out_of_range: Error<
+ "preferred wavesize value %0 must be between %1 and %2">;
+def err_hlsl_wavesize_insufficient_shader_model: Error<
+ "wavesize only takes multiple arguments in shader model 6.8 or higher">;
+
// Layout randomization diagnostics.
def err_non_designated_init_used : Error<
"a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 2ddbee67c414b..a4d76818d29d2 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -38,6 +38,9 @@ class SemaHLSL : public SemaBase {
HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
const AttributeCommonInfo &AL, int X,
int Y, int Z);
+ HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
+ int Min, int Max, int Preferred,
+ int SpelledArgsCount);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
@@ -53,6 +56,7 @@ class SemaHLSL : public SemaBase {
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
+ void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 694a754646f27..c9a7c9e54d13c 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2862,6 +2862,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
NT->getZ());
+ else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr))
+ NewAttr =
+ S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(),
+ NT->getPreferred(), NT->getSpelledArgsCount());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 98e3df9083516..57ae83be12881 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6887,6 +6887,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLNumThreads:
S.HLSL().handleNumThreadsAttr(D, AL);
break;
+ case ParsedAttr::AT_HLSLWaveSize:
+ S.HLSL().handleWaveSizeAttr(D, AL);
+ break;
case ParsedAttr::AT_HLSLSV_GroupIndex:
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9940bc5b4a606..d386897d8251e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -20,7 +20,9 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Frontend/HLSL/HLSLWaveSize.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/TargetParser/Triple.h"
#include <iterator>
@@ -144,6 +146,25 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
}
+HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
+ const AttributeCommonInfo &AL,
+ int Min, int Max, int Preferred,
+ int SpelledArgsCount) {
+ if (HLSLWaveSizeAttr *NT = D->getAttr<HLSLWaveSizeAttr>()) {
+ if (NT->getMin() != Min || NT->getMax() != Max ||
+ NT->getPreferred() != Preferred ||
+ NT->getSpelledArgsCount() != SpelledArgsCount) {
+ Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+ Diag(AL.getLoc(), diag::note_conflicting_attribute);
+ }
+ return nullptr;
+ }
+ HLSLWaveSizeAttr *Result = ::new (getASTContext())
+ HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
+ Result->setSpelledArgsCount(SpelledArgsCount);
+ return Result;
+}
+
HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType) {
@@ -215,7 +236,8 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
-
+ auto &TargetInfo = getASTContext().getTargetInfo();
+ VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
switch (ST) {
case llvm::Triple::Pixel:
case llvm::Triple::Vertex:
@@ -235,6 +257,13 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
llvm::Triple::Mesh});
FD->setInvalidDecl();
}
+ if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+ DiagnoseAttrStageMismatch(NT, ST,
+ {llvm::Triple::Compute,
+ llvm::Triple::Amplification,
+ llvm::Triple::Mesh});
+ FD->setInvalidDecl();
+ }
break;
case llvm::Triple::Compute:
@@ -245,6 +274,20 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
<< llvm::Triple::getEnvironmentTypeName(ST);
FD->setInvalidDecl();
}
+ if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) {
+ if (Ver.getMajor() < 6u ||
+ (Ver.getMajor() == 6u && Ver.getMinor() < 6u)) {
+ Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
+ << "wavesize"
+ << "6.6";
+ FD->setInvalidDecl();
+ } else if (NT->getSpelledArgsCount() > 1 &&
+ (Ver.getMajor() == 6u && Ver.getMinor() < 8u)) {
+ Diag(NT->getLocation(),
+ diag::err_hlsl_wavesize_insufficient_shader_model);
+ FD->setInvalidDecl();
+ }
+ }
break;
default:
llvm_unreachable("Unhandled environment in triple");
@@ -348,6 +391,77 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}
+void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
+ // validate that the wavesize argument is a power of 2 between 4 and 128
+ // inclusive
+ unsigned SpelledArgsCount = AL.getNumArgs();
+ if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
+ return;
+
+ uint32_t Min;
+ if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min))
+ return;
+
+ uint32_t Max = 0;
+ if (SpelledArgsCount > 1 &&
+ !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max))
+ return;
+
+ uint32_t Preferred = 0;
+ if (SpelledArgsCount > 2 &&
+ !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred))
+ return;
+ llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred);
+ llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate();
+ // WaveSize validation succeeds when not defined, but since we have an
+ // attribute, this means min was zero, which is invalid for min.
+ if (ValidationResult == llvm::hlsl::WaveSize::ValidationResult::Success &&
+ !WaveSize.isDefined())
+ ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMin;
+
+ // It is invalid to explicitly specify degenerate cases.
+ if (SpelledArgsCount > 1 && WaveSize.Max == 0)
+ ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMax;
+ else if (SpelledArgsCount > 2 && WaveSize.Preferred == 0)
+ ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred;
+
+ switch (ValidationResult) {
+ case llvm::hlsl::WaveSize::ValidationResult::Success:
+ break;
+ case llvm::hlsl::WaveSize::ValidationResult::InvalidMin:
+ case llvm::hlsl::WaveSize::ValidationResult::InvalidMax:
+ case llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred:
+ case llvm::hlsl::WaveSize::ValidationResult::NoRangeOrMin:
+ Diag(AL.getLoc(), diag::err_hlsl_wavesize_size)
+ << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize;
+ break;
+ case llvm::hlsl::WaveSize::ValidationResult::MaxEqualsMin:
+ Diag(AL.getLoc(), diag::warn_hlsl_wavesize_min_eq_max)
+ << WaveSize.Min << WaveSize.Max;
+ break;
+ case llvm::hlsl::WaveSize::ValidationResult::MaxLessThanMin:
+ Diag(AL.getLoc(), diag::err_hlsl_wavesize_min_geq_max)
+ << WaveSize.Min << WaveSize.Max;
+ break;
+ case llvm::hlsl::WaveSize::ValidationResult::PreferredOutOfRange:
+ Diag(AL.getLoc(), diag::err_hlsl_wavesize_pref_size_out_of_range)
+ << WaveSize.Preferred << WaveSize.Min << WaveSize.Max;
+ break;
+ case llvm::hlsl::WaveSize::ValidationResult::MaxOrPreferredWhenUndefined:
+ case llvm::hlsl::WaveSize::ValidationResult::PreferredWhenNoRange:
+ llvm_unreachable("Should have hit InvalidMax or InvalidPreferred instead.");
+ break;
+ }
+
+ if (ValidationResult != llvm::hlsl::WaveSize::ValidationResult::Success)
+ return;
+
+ HLSLWaveSizeAttr *NewAttr =
+ mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
+ if (NewAttr)
+ D->addAttr(NewAttr);
+}
+
static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
if (!T->hasUnsignedIntegerRepresentation())
return false;
@@ -356,7 +470,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
return true;
}
-void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
+void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
auto *VD = cast<ValueDecl>(D);
if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
diff --git a/clang/test/AST/HLSL/WaveSize.hlsl b/clang/test/AST/HLSL/WaveSize.hlsl
new file mode 100644
index 0000000000000..fd6dc7c94d6d0
--- /dev/null
+++ b/clang/test/AST/HLSL/WaveSize.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w0 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 128 0 0
+ [numthreads(8,8,1)]
+ [WaveSize(128)]
+ void w0() {
+ }
+
+
+
+// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w1 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 64 0
+ [numthreads(8,8,1)]
+ [WaveSize(8, 64)]
+ void w1() {
+ }
+
+
+// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w2 'void ()'
+// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64
+ [numthreads(8,8,1)]
+ [WaveSize(8, 128, 64)]
+ void w2() {
+ }
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
new file mode 100644
index 0000000000000..10c562839eef6
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl
@@ -0,0 +1,101 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl %s -verify
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(1)]
+void e0() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 2)]
+void e1() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 8, 7)]
+void e2() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{minimum wavesize value 16 must be less than maximum wavesize value 8}}
+[WaveSize(16, 8)]
+void e3() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{preferred wavesize value 8 must be between 16 and 128}}
+[WaveSize(16, 128, 8)]
+void e4() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{preferred wavesize value 32 must be between 8 and 16}}
+[WaveSize(8, 16, 32)]
+void e5() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 0)]
+void e6() {
+}
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(4, 4, 0)]
+void e7() {
+}
+
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wave size range minimum and maximum are equal}}
+[WaveSize(16, 16)]
+void e8() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(0)]
+void e9() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}}
+[WaveSize(-4)]
+void e10() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{'WaveSize' attribute takes no more than 3 arguments}}
+[WaveSize(16, 128, 64, 64)]
+void e11() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{'WaveSize' attribute takes at least 1 argument}}
+[WaveSize()]
+void e12() {
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+// expected-error@+1 {{'WaveSize' attribute takes at least 1 argument}}
+[WaveSize]
+void e13() {
+}
diff --git a/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
new file mode 100644
index 0000000000000..13e27a5c4b685
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-pixel -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-vertex -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-geometry -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-hull -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-domain -x hlsl %s -verify
+
+#if __SHADER_TARGET_STAGE == __SHADER_STAGE_PIXEL
+// expected-error@+10 {{attribute 'WaveSize' is unsupported in 'pixel' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_VERTEX
+// expected-error@+8 {{attribute 'WaveSize' is unsupported in 'vertex' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_GEOMETRY
+// expected-error@+6 {{attribute 'WaveSize' is unsupported in 'geometry' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_HULL
+// expected-error@+4 {{attribute 'WaveSize' is unsupported in 'hull' shaders, requires one of the following: compute, amplification, mesh}}
+#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_DOMAIN
+// expected-error@+2 {{attribute 'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}}
+#endif
+[WaveSize(16)]
+void main() {
+}
diff --git a/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
new file mode 100644
index 0000000000000..fb9978c6ce3ce
--- /dev/null
+++ b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl
@@ -0,0 +1,24 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl %s -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.5-library -x hlsl %s -verify
+
+[shader("compute")]
+[numthreads(1,1,1)]
+#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6
+// expected-error@+4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}}
+#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5
+// expected-e...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
clang/lib/Sema/SemaHLSL.cpp
Outdated
!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred)) | ||
return; | ||
llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred); | ||
llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do the WaveSize.validate()
function and the ValidationResult
enum gain us here? The logic below pretty much repeats everything that validate did, just checking enums that are named after the various conditions rather than the conditions themselves. Would it be simpler to just implement the logic here and keep it out of the WaveSize object entirely?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need to validate WaveSize at llvm IR level as well.
WaveSize::validate could be shared between Sema and DirectX backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This validation code looks way too complicated and way too hard to follow. I don't think the code sharing here gets us anything and it seems to come at a cost of a lot of extra complexity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
"attribute %0 requires shader model %1 or greater">; | ||
|
||
def err_hlsl_wavesize_size: Error< | ||
"wavesize arguments must be between 4 and 128 and a power of 2">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we instead use the existing diagnostic err_attribute_power_of_two_in_range
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
"wave size range minimum and maximum are equal">, | ||
InGroup<HLSLAttributeStatement>, DefaultError; | ||
def err_hlsl_wavesize_pref_size_out_of_range: Error< | ||
"preferred wavesize value %0 must be between %1 and %2">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this also just use err_attribute_power_of_two_in_range
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
def err_hlsl_wavesize_size: Error< | ||
"wavesize arguments must be between 4 and 128 and a power of 2">; | ||
def err_hlsl_wavesize_min_geq_max: Error< | ||
"minimum wavesize value %0 must be less than maximum wavesize value %1">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about err_attribute_argument_invalid
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
def warn_hlsl_wavesize_min_eq_max: Warning< | ||
"wave size range minimum and maximum are equal">, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make this more generic and match other attribute formats?
def warn_hlsl_wavesize_min_eq_max: Warning< | |
"wave size range minimum and maximum are equal">, | |
def warn_attr_min_eq_max: Warning< | |
"%0 attribute minimum and maximum arguments are equal">, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
def err_hlsl_wavesize_pref_size_out_of_range: Error< | ||
"preferred wavesize value %0 must be between %1 and %2">; | ||
def err_hlsl_wavesize_insufficient_shader_model: Error< | ||
"wavesize only takes multiple arguments in shader model 6.8 or higher">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this should just use err_attribute_wrong_number_arguments
, and we should have a more generic note to provide context that SM 6.8 adds support for the 3-argument variant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to
def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
"attribute %0 with %1 arguments requires shader model %2 or greater">;
clang/lib/Sema/SemaDecl.cpp
Outdated
@@ -2862,6 +2862,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D, | |||
else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr)) | |||
NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), | |||
NT->getZ()); | |||
else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code block looks ripe for a refactoring to use a switch to detect the attribute type... We should consider doing that cleanup in a follow-up change.
clang/lib/Sema/SemaHLSL.cpp
Outdated
if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) { | ||
if (Ver < VersionTuple(6, 6)) { | ||
Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model) | ||
<< "wavesize" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to put the attribute declaration here rather than the string.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
clang/lib/Sema/SemaHLSL.cpp
Outdated
@@ -245,6 +274,18 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { | |||
<< llvm::Triple::getEnvironmentTypeName(ST); | |||
FD->setInvalidDecl(); | |||
} | |||
if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is NT the name here? That doesn't seem to make sense with the context of what you're assigning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed.
clang/lib/Sema/SemaHLSL.cpp
Outdated
!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred)) | ||
return; | ||
llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred); | ||
llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This validation code looks way too complicated and way too hard to follow. I don't think the code sharing here gets us anything and it seems to come at a cost of a lot of extra complexity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than my nits, looks good to me
First step for support WaveSize attribute in https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html A new attribute HLSLWaveSizeAttr was supported in the AST. Implement both the wave size and the wave size range, rather than separately which might require more work. For llvm#70118
#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_DOMAIN | ||
// expected-error@+2 {{attribute 'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}} | ||
#endif | ||
[WaveSize(16)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you put something like // #WaveSize
at the end of the line here, the checks above can all be @#WaveSize
instead of line offsets.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
clang/lib/Sema/SemaHLSL.cpp
Outdated
@@ -357,6 +398,74 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) { | |||
D->addAttr(NewAttr); | |||
} | |||
|
|||
static bool isValidWaveSizeValue(unsigned Value) { | |||
return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we instead use llvm::isPowerOf2_32
? That is clearer to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few minor suggestions, but otherwise looks good.
clang/lib/Sema/SemaHLSL.cpp
Outdated
const AttributeCommonInfo &AL, | ||
int Min, int Max, int Preferred, | ||
int SpelledArgsCount) { | ||
if (HLSLWaveSizeAttr *NT = D->getAttr<HLSLWaveSizeAttr>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more "NT" to rename.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/30/builds/5200 Here is the relevant piece of the build log for the reference
|
First step for support WaveSize attribute in
https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html
and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
A new attribute HLSLWaveSizeAttr was supported in the AST.
Implement both the wave size and the wave size range, rather than separately which might require more work.
For #70118