Skip to content

Commit 7549f42

Browse files
authored
[HLSL][RootSignature] Add parsing of flags to RootDescriptor (#140152)
- defines RootDescriptorFlags in-memory representation - defines parseRootDescriptorFlags to be DXC compatible. This is why we support multiple `|` flags even though validation will assert that only one flag is set - add unit tests to demonstrate functionality Final part of and resolves #126577
1 parent 51a03ed commit 7549f42

File tree

4 files changed

+107
-4
lines changed

4 files changed

+107
-4
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class RootSignatureParser {
9393
std::optional<llvm::hlsl::rootsig::Register> Reg;
9494
std::optional<uint32_t> Space;
9595
std::optional<llvm::hlsl::rootsig::ShaderVisibility> Visibility;
96+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> Flags;
9697
};
9798
std::optional<ParsedRootDescriptorParams>
9899
parseRootDescriptorParams(RootSignatureToken::Kind RegType);
@@ -113,6 +114,8 @@ class RootSignatureParser {
113114

114115
/// Parsing methods of various enums
115116
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
117+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
118+
parseRootDescriptorFlags();
116119
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
117120
parseDescriptorRangeFlags();
118121

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
193193
ExpectedReg = TokenKind::uReg;
194194
break;
195195
}
196+
Descriptor.setDefaultFlags();
196197

197198
auto Params = parseRootDescriptorParams(ExpectedReg);
198199
if (!Params.has_value())
@@ -214,6 +215,9 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
214215
if (Params->Visibility.has_value())
215216
Descriptor.Visibility = Params->Visibility.value();
216217

218+
if (Params->Flags.has_value())
219+
Descriptor.Flags = Params->Flags.value();
220+
217221
if (consumeExpectedToken(TokenKind::pu_r_paren,
218222
diag::err_hlsl_unexpected_end_of_params,
219223
/*param of=*/TokenKind::kw_RootConstants))
@@ -475,6 +479,23 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
475479
return std::nullopt;
476480
Params.Visibility = Visibility;
477481
}
482+
483+
// `flags` `=` ROOT_DESCRIPTOR_FLAGS
484+
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
485+
if (Params.Flags.has_value()) {
486+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
487+
<< CurToken.TokKind;
488+
return std::nullopt;
489+
}
490+
491+
if (consumeExpectedToken(TokenKind::pu_equal))
492+
return std::nullopt;
493+
494+
auto Flags = parseRootDescriptorFlags();
495+
if (!Flags.has_value())
496+
return std::nullopt;
497+
Params.Flags = Flags;
498+
}
478499
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
479500

480501
return Params;
@@ -654,6 +675,45 @@ RootSignatureParser::parseShaderVisibility() {
654675
return std::nullopt;
655676
}
656677

678+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
679+
RootSignatureParser::parseRootDescriptorFlags() {
680+
assert(CurToken.TokKind == TokenKind::pu_equal &&
681+
"Expects to only be invoked starting at given keyword");
682+
683+
// Handle the edge-case of '0' to specify no flags set
684+
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
685+
if (!verifyZeroFlag()) {
686+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
687+
return std::nullopt;
688+
}
689+
return RootDescriptorFlags::None;
690+
}
691+
692+
TokenKind Expected[] = {
693+
#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
694+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
695+
};
696+
697+
std::optional<RootDescriptorFlags> Flags;
698+
699+
do {
700+
if (tryConsumeExpectedToken(Expected)) {
701+
switch (CurToken.TokKind) {
702+
#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) \
703+
case TokenKind::en_##NAME: \
704+
Flags = \
705+
maybeOrFlag<RootDescriptorFlags>(Flags, RootDescriptorFlags::NAME); \
706+
break;
707+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
708+
default:
709+
llvm_unreachable("Switch for consumed enum token was not provided");
710+
}
711+
}
712+
} while (tryConsumeExpectedToken(TokenKind::pu_or));
713+
714+
return Flags;
715+
}
716+
657717
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
658718
RootSignatureParser::parseDescriptorRangeFlags() {
659719
assert(CurToken.TokKind == TokenKind::pu_equal &&

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,11 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
347347
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
348348
const llvm::StringLiteral Source = R"cc(
349349
CBV(b0),
350-
SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY),
351-
UAV(visibility = SHADER_VISIBILITY_HULL, u34893247)
350+
SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY,
351+
flags = DATA_VOLATILE | DATA_STATIC | DATA_STATIC_WHILE_SET_AT_EXECUTE
352+
),
353+
UAV(visibility = SHADER_VISIBILITY_HULL, u34893247),
354+
CBV(b0, flags = 0),
352355
)cc";
353356

354357
TrivialModuleLoader ModLoader;
@@ -364,7 +367,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
364367

365368
ASSERT_FALSE(Parser.parse());
366369

367-
ASSERT_EQ(Elements.size(), 3u);
370+
ASSERT_EQ(Elements.size(), 4u);
368371

369372
RootElement Elem = Elements[0];
370373
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
@@ -373,6 +376,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
373376
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 0u);
374377
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
375378
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility, ShaderVisibility::All);
379+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
380+
RootDescriptorFlags::DataStaticWhileSetAtExecute);
376381

377382
Elem = Elements[1];
378383
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
@@ -382,6 +387,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
382387
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 4u);
383388
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility,
384389
ShaderVisibility::Geometry);
390+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
391+
RootDescriptorFlags::ValidFlags);
385392

386393
Elem = Elements[2];
387394
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
@@ -390,6 +397,18 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
390397
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 34893247u);
391398
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
392399
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility, ShaderVisibility::Hull);
400+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
401+
RootDescriptorFlags::DataVolatile);
402+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
403+
RootDescriptorFlags::DataVolatile);
404+
405+
Elem = Elements[3];
406+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer);
407+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::BReg);
408+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 0u);
409+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
410+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility, ShaderVisibility::All);
411+
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, RootDescriptorFlags::None);
393412

394413
ASSERT_TRUE(Consumer->isSatisfied());
395414
}

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ enum class RootFlags : uint32_t {
4646
ValidFlags = 0x00000fff
4747
};
4848

49+
enum class RootDescriptorFlags : unsigned {
50+
None = 0,
51+
DataVolatile = 0x2,
52+
DataStaticWhileSetAtExecute = 0x4,
53+
DataStatic = 0x8,
54+
ValidFlags = 0xe,
55+
};
56+
4957
enum class DescriptorRangeFlags : unsigned {
5058
None = 0,
5159
DescriptorsVolatile = 0x1,
@@ -85,13 +93,26 @@ struct RootConstants {
8593
ShaderVisibility Visibility = ShaderVisibility::All;
8694
};
8795

88-
using DescriptorType = llvm::dxil::ResourceClass;
96+
enum class DescriptorType : uint8_t { SRV = 0, UAV, CBuffer };
8997
// Models RootDescriptor : CBV | SRV | UAV, by collecting like parameters
9098
struct RootDescriptor {
9199
DescriptorType Type;
92100
Register Reg;
93101
uint32_t Space = 0;
94102
ShaderVisibility Visibility = ShaderVisibility::All;
103+
RootDescriptorFlags Flags;
104+
105+
void setDefaultFlags() {
106+
switch (Type) {
107+
case DescriptorType::CBuffer:
108+
case DescriptorType::SRV:
109+
Flags = RootDescriptorFlags::DataStaticWhileSetAtExecute;
110+
break;
111+
case DescriptorType::UAV:
112+
Flags = RootDescriptorFlags::DataVolatile;
113+
break;
114+
}
115+
}
95116
};
96117

97118
// Models the end of a descriptor table and stores its visibility

0 commit comments

Comments
 (0)