Skip to content

Commit b3963d3

Browse files
authored
[HLSL][RootSignature] Add parsing for RootFlags (#138055)
- defines the `RootFlags` in-memory enum - defines `parseRootFlags` to parse the various flag enums into a single `uint32_t` - adds corresponding unit tests - improves the diagnostic message for when we provide a non-zero integer value to the flags Resolves #126575
1 parent c60db55 commit b3963d3

File tree

7 files changed

+168
-14
lines changed

7 files changed

+168
-14
lines changed

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,5 +1856,6 @@ def err_hlsl_unexpected_end_of_params
18561856
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
18571857
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
18581858
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
1859+
def err_hlsl_rootsig_non_zero_flag : Error<"flag value is neither a literal 0 nor a named value">;
18591860

18601861
} // end of Parser diagnostics

clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#endif
2828

2929
// Defines the various types of enum
30+
#ifndef ROOT_FLAG_ENUM
31+
#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
32+
#endif
3033
#ifndef UNBOUNDED_ENUM
3134
#define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT)
3235
#endif
@@ -74,6 +77,7 @@ PUNCTUATOR(minus, '-')
7477

7578
// RootElement Keywords:
7679
KEYWORD(RootSignature) // used only for diagnostic messaging
80+
KEYWORD(RootFlags)
7781
KEYWORD(DescriptorTable)
7882
KEYWORD(RootConstants)
7983

@@ -101,6 +105,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded")
101105
// Descriptor Range Offset Enum:
102106
DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")
103107

108+
// Root Flag Enums:
109+
ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT")
110+
ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS")
111+
ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS")
112+
ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS")
113+
ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS")
114+
ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS")
115+
ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS")
116+
ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS")
117+
ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT")
118+
ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE")
119+
ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED")
120+
ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED")
121+
104122
// Root Descriptor Flag Enums:
105123
ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE")
106124
ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE")
@@ -128,6 +146,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH")
128146
#undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
129147
#undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
130148
#undef ROOT_DESCRIPTOR_FLAG_ENUM
149+
#undef ROOT_FLAG_ENUM
131150
#undef DESCRIPTOR_RANGE_OFFSET_ENUM
132151
#undef UNBOUNDED_ENUM
133152
#undef ENUM

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class RootSignatureParser {
7171
// expected, or, there is a lexing error
7272

7373
/// Root Element parse methods:
74+
std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
7475
std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
7576
std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
7677
std::optional<llvm::hlsl::rootsig::DescriptorTableClause>

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
2727
bool RootSignatureParser::parse() {
2828
// Iterate as many RootElements as possible
2929
do {
30+
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
31+
auto Flags = parseRootFlags();
32+
if (!Flags.has_value())
33+
return true;
34+
Elements.push_back(*Flags);
35+
}
36+
3037
if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
3138
auto Constants = parseRootConstants();
3239
if (!Constants.has_value())
@@ -47,6 +54,61 @@ bool RootSignatureParser::parse() {
4754
/*param of=*/TokenKind::kw_RootSignature);
4855
}
4956

57+
template <typename FlagType>
58+
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
59+
if (!Flags.has_value())
60+
return Flag;
61+
62+
return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
63+
llvm::to_underlying(Flag));
64+
}
65+
66+
std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
67+
assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
68+
"Expects to only be invoked starting at given keyword");
69+
70+
if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
71+
CurToken.TokKind))
72+
return std::nullopt;
73+
74+
std::optional<RootFlags> Flags = RootFlags::None;
75+
76+
// Handle the edge-case of '0' to specify no flags set
77+
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
78+
if (!verifyZeroFlag()) {
79+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
80+
return std::nullopt;
81+
}
82+
} else {
83+
// Otherwise, parse as many flags as possible
84+
TokenKind Expected[] = {
85+
#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
86+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
87+
};
88+
89+
do {
90+
if (tryConsumeExpectedToken(Expected)) {
91+
switch (CurToken.TokKind) {
92+
#define ROOT_FLAG_ENUM(NAME, LIT) \
93+
case TokenKind::en_##NAME: \
94+
Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
95+
break;
96+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
97+
default:
98+
llvm_unreachable("Switch for consumed enum token was not provided");
99+
}
100+
}
101+
} while (tryConsumeExpectedToken(TokenKind::pu_or));
102+
}
103+
104+
if (consumeExpectedToken(TokenKind::pu_r_paren,
105+
diag::err_hlsl_unexpected_end_of_params,
106+
/*param of=*/TokenKind::kw_RootFlags))
107+
return std::nullopt;
108+
109+
return Flags;
110+
}
111+
50112
std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
51113
assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
52114
"Expects to only be invoked starting at given keyword");
@@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() {
467529
return std::nullopt;
468530
}
469531

470-
template <typename FlagType>
471-
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
472-
if (!Flags.has_value())
473-
return Flag;
474-
475-
return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
476-
llvm::to_underlying(Flag));
477-
}
478-
479532
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
480533
RootSignatureParser::parseDescriptorRangeFlags() {
481534
assert(CurToken.TokKind == TokenKind::pu_equal &&
@@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() {
484537
// Handle the edge-case of '0' to specify no flags set
485538
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
486539
if (!verifyZeroFlag()) {
487-
getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
540+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
488541
return std::nullopt;
489542
}
490543
return DescriptorRangeFlags::None;

clang/unittests/Lex/LexHLSLRootSignatureTest.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
128128
129129
RootSignature
130130
131-
DescriptorTable RootConstants
131+
RootFlags DescriptorTable RootConstants
132132
133133
num32BitConstants
134134
@@ -139,6 +139,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
139139
unbounded
140140
DESCRIPTOR_RANGE_OFFSET_APPEND
141141
142+
allow_input_assembler_input_layout
143+
deny_vertex_shader_root_access
144+
deny_hull_shader_root_access
145+
deny_domain_shader_root_access
146+
deny_geometry_shader_root_access
147+
deny_pixel_shader_root_access
148+
deny_amplification_shader_root_access
149+
deny_mesh_shader_root_access
150+
allow_stream_output
151+
local_root_signature
152+
cbv_srv_uav_heap_directly_indexed
153+
sampler_heap_directly_indexed
154+
142155
DATA_VOLATILE
143156
DATA_STATIC_WHILE_SET_AT_EXECUTE
144157
DATA_STATIC

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
294294
ASSERT_TRUE(Consumer->isSatisfied());
295295
}
296296

297+
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
298+
const llvm::StringLiteral Source = R"cc(
299+
RootFlags(),
300+
RootFlags(0),
301+
RootFlags(
302+
deny_domain_shader_root_access |
303+
deny_pixel_shader_root_access |
304+
local_root_signature |
305+
cbv_srv_uav_heap_directly_indexed |
306+
deny_amplification_shader_root_access |
307+
deny_geometry_shader_root_access |
308+
deny_hull_shader_root_access |
309+
deny_mesh_shader_root_access |
310+
allow_stream_output |
311+
sampler_heap_directly_indexed |
312+
allow_input_assembler_input_layout |
313+
deny_vertex_shader_root_access
314+
)
315+
)cc";
316+
317+
TrivialModuleLoader ModLoader;
318+
auto PP = createPP(Source, ModLoader);
319+
auto TokLoc = SourceLocation();
320+
321+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
322+
SmallVector<RootElement> Elements;
323+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
324+
325+
// Test no diagnostics produced
326+
Consumer->setNoDiag();
327+
328+
ASSERT_FALSE(Parser.parse());
329+
330+
ASSERT_EQ(Elements.size(), 3u);
331+
332+
RootElement Elem = Elements[0];
333+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
334+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
335+
336+
Elem = Elements[1];
337+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
338+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
339+
340+
Elem = Elements[2];
341+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
342+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags);
343+
344+
ASSERT_TRUE(Consumer->isSatisfied());
345+
}
346+
297347
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
298348
// This test will checks we can handling trailing commas ','
299349
const llvm::StringLiteral Source = R"cc(
@@ -566,7 +616,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
566616
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
567617

568618
// Test correct diagnostic produced
569-
Consumer->setExpected(diag::err_expected);
619+
Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag);
570620
ASSERT_TRUE(Parser.parse());
571621

572622
ASSERT_TRUE(Consumer->isSatisfied());

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@ namespace rootsig {
2323

2424
// Definition of the various enumerations and flags
2525

26+
enum class RootFlags : uint32_t {
27+
None = 0,
28+
AllowInputAssemblerInputLayout = 0x1,
29+
DenyVertexShaderRootAccess = 0x2,
30+
DenyHullShaderRootAccess = 0x4,
31+
DenyDomainShaderRootAccess = 0x8,
32+
DenyGeometryShaderRootAccess = 0x10,
33+
DenyPixelShaderRootAccess = 0x20,
34+
AllowStreamOutput = 0x40,
35+
LocalRootSignature = 0x80,
36+
DenyAmplificationShaderRootAccess = 0x100,
37+
DenyMeshShaderRootAccess = 0x200,
38+
CBVSRVUAVHeapDirectlyIndexed = 0x400,
39+
SamplerHeapDirectlyIndexed = 0x800,
40+
ValidFlags = 0x00000fff
41+
};
42+
2643
enum class DescriptorRangeFlags : unsigned {
2744
None = 0,
2845
DescriptorsVolatile = 0x1,
@@ -97,8 +114,8 @@ struct DescriptorTableClause {
97114
};
98115

99116
// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
100-
using RootElement =
101-
std::variant<RootConstants, DescriptorTable, DescriptorTableClause>;
117+
using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
118+
DescriptorTableClause>;
102119

103120
} // namespace rootsig
104121
} // namespace hlsl

0 commit comments

Comments
 (0)