Skip to content

Commit bcc4357

Browse files
committed
add support for optional parameters
- use numDescriptors as an example
1 parent 56664e2 commit bcc4357

File tree

4 files changed

+82
-1
lines changed

4 files changed

+82
-1
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,22 @@ class RootSignatureParser {
106106
bool ParseDescriptorTableClause();
107107

108108
// Helper dispatch method
109+
//
110+
// These will switch on the Variant kind to dispatch to the respective Parse
111+
// method and store the parsed value back into Ref.
112+
//
113+
// It is helpful to have a generalized dispatch method so that when we need
114+
// to parse multiple optional parameters in any order, we can invoke this
115+
// method
116+
bool ParseParam(rs::ParamType Ref);
117+
118+
// Parse as many optional parameters as possible in any order
119+
bool
120+
ParseOptionalParams(llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap);
121+
122+
// Common parsing helpers
109123
bool ParseRegister(rs::Register *Reg);
124+
bool ParseUInt(uint32_t *X);
110125

111126
// Increment the token iterator if we have not reached the end.
112127
// Return value denotes if we were already at the last token.

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,71 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
325325
if (ParseRegister(&Clause.Register))
326326
return true;
327327

328+
// Parse optional paramaters
329+
llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap = {
330+
{TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
331+
};
332+
if (ParseOptionalParams({RefMap}))
333+
return true;
334+
328335
if (ConsumeExpectedToken(TokenKind::pu_r_paren))
329336
return true;
330337

331338
Elements.push_back(Clause);
332339
return false;
333340
}
334341

342+
// Helper struct so that we can use the overloaded notation of std::visit
343+
template <class... Ts> struct OverloadedMethods : Ts... {
344+
using Ts::operator()...;
345+
};
346+
template <class... Ts> OverloadedMethods(Ts...) -> OverloadedMethods<Ts...>;
347+
348+
bool RootSignatureParser::ParseParam(ParamType Ref) {
349+
if (ConsumeExpectedToken(TokenKind::pu_equal))
350+
return true;
351+
352+
bool Error;
353+
std::visit(OverloadedMethods{[&](uint32_t *X) { Error = ParseUInt(X); },
354+
}, Ref);
355+
356+
return Error;
357+
}
358+
359+
bool RootSignatureParser::ParseOptionalParams(
360+
llvm::SmallDenseMap<TokenKind, rs::ParamType> RefMap) {
361+
SmallVector<TokenKind> ParamKeywords;
362+
for (auto RefPair : RefMap)
363+
ParamKeywords.push_back(RefPair.first);
364+
365+
// Keep track of which keywords have been seen to report duplicates
366+
llvm::SmallDenseSet<TokenKind> Seen;
367+
368+
while (!TryConsumeExpectedToken(TokenKind::pu_comma)) {
369+
if (ConsumeExpectedToken(ParamKeywords))
370+
return true;
371+
372+
TokenKind ParamKind = CurTok->Kind;
373+
if (Seen.contains(ParamKind)) {
374+
return true;
375+
}
376+
Seen.insert(ParamKind);
377+
378+
if (ParseParam(RefMap[ParamKind]))
379+
return true;
380+
}
381+
382+
return false;
383+
}
384+
385+
bool RootSignatureParser::ParseUInt(uint32_t *X) {
386+
if (ConsumeExpectedToken(TokenKind::int_literal))
387+
return true;
388+
389+
*X = CurTok->NumLiteral.getInt().getExtValue();
390+
return false;
391+
}
392+
335393
bool RootSignatureParser::ParseRegister(Register *Register) {
336394
switch (CurTok->Kind) {
337395
case TokenKind::bReg:

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
309309
const llvm::StringLiteral Source = R"cc(
310310
DescriptorTable(
311311
CBV(b0),
312-
SRV(t42),
312+
SRV(t42, numDescriptors = 4),
313313
Sampler(s987),
314314
UAV(u987234)
315315
),
@@ -338,6 +338,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
338338
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
339339
RegisterType::BReg);
340340
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, (uint32_t)0);
341+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
341342

342343
Elem = Elements[1];
343344
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -346,6 +347,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
346347
RegisterType::TReg);
347348
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
348349
(uint32_t)42);
350+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)4);
349351

350352
Elem = Elements[2];
351353
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -354,6 +356,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
354356
RegisterType::SReg);
355357
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
356358
(uint32_t)987);
359+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
357360

358361
Elem = Elements[3];
359362
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -362,6 +365,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
362365
RegisterType::UReg);
363366
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
364367
(uint32_t)987234);
368+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
365369

366370
Elem = Elements[4];
367371
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,15 @@ using ClauseType = llvm::dxil::ResourceClass;
4040
struct DescriptorTableClause {
4141
ClauseType Type;
4242
Register Register;
43+
uint32_t NumDescriptors = 1;
4344
};
4445

4546
// Models RootElement : DescriptorTable | DescriptorTableClause
4647
using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
4748

49+
// Models a reference to all assignment parameter types that any RootElement
50+
// may have. Things of the form: Keyword = Param
51+
using ParamType = std::variant<uint32_t *>;
4852
} // namespace root_signature
4953
} // namespace hlsl
5054
} // namespace llvm

0 commit comments

Comments
 (0)