Skip to content

Commit b499f7f

Browse files
authored
[HLSL][RootSignature] Add parsing for empty RootDescriptors (#140147)
- define the RootDescriptor in-memory struct containing its type - add test harness for testing First part of #126577
1 parent 4f869e0 commit b499f7f

File tree

4 files changed

+91
-4
lines changed

4 files changed

+91
-4
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class RootSignatureParser {
7373
/// Root Element parse methods:
7474
std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
7575
std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
76+
std::optional<llvm::hlsl::rootsig::RootDescriptor> parseRootDescriptor();
7677
std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
7778
std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
7879
parseDescriptorTableClause();

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ bool RootSignatureParser::parse() {
4747
return true;
4848
Elements.push_back(*Table);
4949
}
50+
51+
if (tryConsumeExpectedToken(
52+
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
53+
auto Descriptor = parseRootDescriptor();
54+
if (!Descriptor.has_value())
55+
return true;
56+
Elements.push_back(*Descriptor);
57+
}
5058
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
5159

5260
return consumeExpectedToken(TokenKind::end_of_stream,
@@ -155,6 +163,41 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
155163
return Constants;
156164
}
157165

166+
std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
167+
assert((CurToken.TokKind == TokenKind::kw_CBV ||
168+
CurToken.TokKind == TokenKind::kw_SRV ||
169+
CurToken.TokKind == TokenKind::kw_UAV) &&
170+
"Expects to only be invoked starting at given keyword");
171+
172+
TokenKind DescriptorKind = CurToken.TokKind;
173+
174+
if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
175+
CurToken.TokKind))
176+
return std::nullopt;
177+
178+
RootDescriptor Descriptor;
179+
switch (DescriptorKind) {
180+
default:
181+
llvm_unreachable("Switch for consumed token was not provided");
182+
case TokenKind::kw_CBV:
183+
Descriptor.Type = DescriptorType::CBuffer;
184+
break;
185+
case TokenKind::kw_SRV:
186+
Descriptor.Type = DescriptorType::SRV;
187+
break;
188+
case TokenKind::kw_UAV:
189+
Descriptor.Type = DescriptorType::UAV;
190+
break;
191+
}
192+
193+
if (consumeExpectedToken(TokenKind::pu_r_paren,
194+
diag::err_hlsl_unexpected_end_of_params,
195+
/*param of=*/TokenKind::kw_RootConstants))
196+
return std::nullopt;
197+
198+
return Descriptor;
199+
}
200+
158201
std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
159202
assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
160203
"Expects to only be invoked starting at given keyword");

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,43 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
344344
ASSERT_TRUE(Consumer->isSatisfied());
345345
}
346346

347+
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
348+
const llvm::StringLiteral Source = R"cc(
349+
CBV(),
350+
SRV(),
351+
UAV()
352+
)cc";
353+
354+
TrivialModuleLoader ModLoader;
355+
auto PP = createPP(Source, ModLoader);
356+
auto TokLoc = SourceLocation();
357+
358+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
359+
SmallVector<RootElement> Elements;
360+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
361+
362+
// Test no diagnostics produced
363+
Consumer->setNoDiag();
364+
365+
ASSERT_FALSE(Parser.parse());
366+
367+
ASSERT_EQ(Elements.size(), 3u);
368+
369+
RootElement Elem = Elements[0];
370+
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
371+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer);
372+
373+
Elem = Elements[1];
374+
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
375+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::SRV);
376+
377+
Elem = Elements[2];
378+
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
379+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::UAV);
380+
381+
ASSERT_TRUE(Consumer->isSatisfied());
382+
}
383+
347384
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
348385
// This test will checks we can handling trailing commas ','
349386
const llvm::StringLiteral Source = R"cc(

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ struct RootConstants {
8585
ShaderVisibility Visibility = ShaderVisibility::All;
8686
};
8787

88+
using DescriptorType = llvm::dxil::ResourceClass;
89+
// Models RootDescriptor : CBV | SRV | UAV, by collecting like parameters
90+
struct RootDescriptor {
91+
DescriptorType Type;
92+
};
93+
8894
// Models the end of a descriptor table and stores its visibility
8995
struct DescriptorTable {
9096
ShaderVisibility Visibility = ShaderVisibility::All;
@@ -125,8 +131,8 @@ struct DescriptorTableClause {
125131
void dump(raw_ostream &OS) const;
126132
};
127133

128-
/// Models RootElement : RootFlags | RootConstants | DescriptorTable
129-
/// | DescriptorTableClause
134+
/// Models RootElement : RootFlags | RootConstants | RootDescriptor
135+
/// | DescriptorTable | DescriptorTableClause
130136
///
131137
/// A Root Signature is modeled in-memory by an array of RootElements. These
132138
/// aim to map closely to their DSL grammar reprsentation defined in the spec.
@@ -140,8 +146,8 @@ struct DescriptorTableClause {
140146
/// The DescriptorTable is modelled by having its Clauses as the previous
141147
/// RootElements in the array, and it holds a data member for the Visibility
142148
/// parameter.
143-
using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
144-
DescriptorTableClause>;
149+
using RootElement = std::variant<RootFlags, RootConstants, RootDescriptor,
150+
DescriptorTable, DescriptorTableClause>;
145151

146152
void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements);
147153

0 commit comments

Comments
 (0)