Skip to content

[HLSL][RootSignature] Add parsing of DescriptorRangeFlags #136775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class RootSignatureParser {
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> Space;
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
};
std::optional<ParsedClauseParams>
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
Expand All @@ -91,11 +92,19 @@ class RootSignatureParser {

/// Parsing methods of various enums
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
parseDescriptorRangeFlags();

/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
std::optional<uint32_t> handleUIntLiteral();

/// Flags may specify the value of '0' to denote that there should be no
/// flags set.
///
/// Return true if the current int_literal token is '0', otherwise false
bool verifyZeroFlag();

/// Invoke the Lexer to consume a token and update CurToken with the result
void consumeNextToken() { CurToken = Lexer.consumeToken(); }

Expand Down
76 changes: 76 additions & 0 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
ExpectedReg = TokenKind::sReg;
break;
}
Clause.setDefaultFlags();

auto Params = parseDescriptorTableClauseParams(ExpectedReg);
if (!Params.has_value())
Expand All @@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
if (Params->Space.has_value())
Clause.Space = Params->Space.value();

if (Params->Flags.has_value())
Clause.Flags = Params->Flags.value();

if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/ParamKind))
Expand Down Expand Up @@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
return std::nullopt;
Params.Space = Space;
}

// `flags` `=` DESCRIPTOR_RANGE_FLAGS
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
if (Params.Flags.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
<< CurToken.TokKind;
return std::nullopt;
}

if (consumeExpectedToken(TokenKind::pu_equal))
return std::nullopt;

auto Flags = parseDescriptorRangeFlags();
if (!Flags.has_value())
return std::nullopt;
Params.Flags = Flags;
}

} while (tryConsumeExpectedToken(TokenKind::pu_comma));

return Params;
Expand Down Expand Up @@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}

template <typename FlagType>
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
if (!Flags.has_value())
return Flag;

return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
llvm::to_underlying(Flag));
}

std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
RootSignatureParser::parseDescriptorRangeFlags() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
"Expects to only be invoked starting at given keyword");

// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
return std::nullopt;
}
return DescriptorRangeFlags::None;
}

TokenKind Expected[] = {
#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
};

std::optional<DescriptorRangeFlags> Flags;

do {
if (tryConsumeExpectedToken(Expected)) {
switch (CurToken.TokKind) {
#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) \
case TokenKind::en_##NAME: \
Flags = \
maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME); \
break;
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
default:
llvm_unreachable("Switch for consumed enum token was not provided");
}
}
} while (tryConsumeExpectedToken(TokenKind::pu_or));

return Flags;
}

std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
Expand All @@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
return Val.getExtValue();
}

bool RootSignatureParser::verifyZeroFlag() {
assert(CurToken.TokKind == TokenKind::int_literal);
auto X = handleUIntLiteral();
return X.has_value() && X.value() == 0;
}

bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
return peekExpectedToken(ArrayRef{Expected});
}
Expand Down
69 changes: 67 additions & 2 deletions clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
const llvm::StringLiteral Source = R"cc(
DescriptorTable(
CBV(b0),
SRV(space = 3, t42),
SRV(space = 3, t42, flags = 0),
visibility = SHADER_VISIBILITY_PIXEL,
Sampler(s987, space = +2),
UAV(u4294967294)
UAV(u4294967294,
flags = Descriptors_Volatile | Data_Volatile
| Data_Static_While_Set_At_Execute | Data_Static
| Descriptors_Static_Keeping_Buffer_Bounds_Checks
)
),
DescriptorTable()
)cc";
Expand All @@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::BReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
DescriptorRangeFlags::DataStaticWhileSetAtExecute);

Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
Expand All @@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::TReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
DescriptorRangeFlags::None);

Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
Expand All @@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::SReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
DescriptorRangeFlags::None);

Elem = Elements[3];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
Expand All @@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::UReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
DescriptorRangeFlags::ValidFlags);

Elem = Elements[4];
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
Expand All @@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
// This test will checks we can set the valid enum for Sampler descriptor
// range flags
const llvm::StringLiteral Source = R"cc(
DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
)cc";

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);
auto TokLoc = SourceLocation();

hlsl::RootSignatureLexer Lexer(Source, TokLoc);
SmallVector<RootElement> Elements;
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);

// Test no diagnostics produced
Consumer->setNoDiag();

ASSERT_FALSE(Parser.parse());

RootElement Elem = Elements[0];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
DescriptorRangeFlags::ValidSamplerFlags);

ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
Expand Down Expand Up @@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
// This test will check that parsing fails when a non-zero integer literal
// is given to flags
const llvm::StringLiteral Source = R"cc(
DescriptorTable(
CBV(b0, flags = 3)
)
)cc";

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);
auto TokLoc = SourceLocation();

hlsl::RootSignatureLexer Lexer(Source, TokLoc);
SmallVector<RootElement> Elements;
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);

// Test correct diagnostic produced
Consumer->setExpected(diag::err_expected);
ASSERT_TRUE(Parser.parse());

ASSERT_TRUE(Consumer->isSatisfied());
}

} // anonymous namespace
27 changes: 27 additions & 0 deletions llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ namespace rootsig {

// Definition of the various enumerations and flags

enum class DescriptorRangeFlags : unsigned {
None = 0,
DescriptorsVolatile = 0x1,
DataVolatile = 0x2,
DataStaticWhileSetAtExecute = 0x4,
DataStatic = 0x8,
DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
ValidFlags = 0x1000f,
ValidSamplerFlags = DescriptorsVolatile,
};

enum class ShaderVisibility {
All = 0,
Vertex = 1,
Expand Down Expand Up @@ -55,6 +66,22 @@ struct DescriptorTableClause {
ClauseType Type;
Register Reg;
uint32_t Space = 0;
DescriptorRangeFlags Flags;

void setDefaultFlags() {
switch (Type) {
case ClauseType::CBuffer:
case ClauseType::SRV:
Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
break;
case ClauseType::UAV:
Flags = DescriptorRangeFlags::DataVolatile;
break;
case ClauseType::Sampler:
Flags = DescriptorRangeFlags::None;
break;
}
}
};

// Models RootElement : DescriptorTable | DescriptorTableClause
Expand Down
Loading