Skip to content

[HLSL][RootSignature] Add parsing of floats for StaticSampler #140181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1856,7 +1856,11 @@ def err_hlsl_unexpected_end_of_params
: Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
def err_hlsl_number_literal_overflow : Error<
"%select{integer|float}0 literal is too large to be represented as a "
"%select{32-bit %select{signed|}1 integer|float}0 type">;
def err_hlsl_number_literal_underflow : Error<
"float literal has a magnitude that is too small to be represented as a float type">;
def err_hlsl_rootsig_non_zero_flag : Error<"flag value is neither a literal 0 nor a named value">;

} // end of Parser diagnostics
3 changes: 3 additions & 0 deletions clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ KEYWORD(flags)
KEYWORD(numDescriptors)
KEYWORD(offset)

// StaticSampler Keywords:
KEYWORD(mipLODBias)

// Unbounded Enum:
UNBOUNDED_ENUM(unbounded, "unbounded")

Expand Down
15 changes: 15 additions & 0 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,14 @@ class RootSignatureParser {

struct ParsedStaticSamplerParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<float> MipLODBias;
};
std::optional<ParsedStaticSamplerParams> parseStaticSamplerParams();

// Common parsing methods
std::optional<uint32_t> parseUIntParam();
std::optional<llvm::hlsl::rootsig::Register> parseRegister();
std::optional<float> parseFloatParam();

/// Parsing methods of various enums
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
Expand All @@ -128,6 +130,19 @@ class RootSignatureParser {
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
std::optional<uint32_t> handleUIntLiteral();
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a signed
/// 32-bit integer
std::optional<int32_t> handleIntLiteral(bool Negated);
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a float
///
/// This matches the behaviour of DXC, which is as follows:
/// - convert the spelling with `strtod`
/// - check for a float overflow
/// - cast the double to a float
/// The behaviour of `strtod` is replicated using:
/// Semantics: llvm::APFloat::Semantics::S_IEEEdouble
/// RoundingMode: llvm::RoundingMode::NearestTiesToEven
std::optional<float> handleFloatLiteral(bool Negated);

/// Flags may specify the value of '0' to denote that there should be no
/// flags set.
Expand Down
161 changes: 157 additions & 4 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {

Sampler.Reg = Params->Reg.value();

// Fill in optional values
if (Params->MipLODBias.has_value())
Sampler.MipLODBias = Params->MipLODBias.value();

if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_StaticSampler))
Expand Down Expand Up @@ -661,6 +665,23 @@ RootSignatureParser::parseStaticSamplerParams() {
return std::nullopt;
Params.Reg = Reg;
}

// `mipLODBias` `=` NUMBER
if (tryConsumeExpectedToken(TokenKind::kw_mipLODBias)) {
if (Params.MipLODBias.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
<< CurToken.TokKind;
return std::nullopt;
}

if (consumeExpectedToken(TokenKind::pu_equal))
return std::nullopt;

auto MipLODBias = parseFloatParam();
if (!MipLODBias.has_value())
return std::nullopt;
Params.MipLODBias = MipLODBias;
}
} while (tryConsumeExpectedToken(TokenKind::pu_comma));

return Params;
Expand Down Expand Up @@ -709,6 +730,39 @@ std::optional<Register> RootSignatureParser::parseRegister() {
return Reg;
}

std::optional<float> RootSignatureParser::parseFloatParam() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
"Expects to only be invoked starting at given keyword");
// Consume sign modifier
bool Signed =
tryConsumeExpectedToken({TokenKind::pu_plus, TokenKind::pu_minus});
bool Negated = Signed && CurToken.TokKind == TokenKind::pu_minus;

// DXC will treat a postive signed integer as unsigned
if (!Negated && tryConsumeExpectedToken(TokenKind::int_literal)) {
std::optional<uint32_t> UInt = handleUIntLiteral();
if (!UInt.has_value())
return std::nullopt;
return float(UInt.value());
}

if (Negated && tryConsumeExpectedToken(TokenKind::int_literal)) {
std::optional<int32_t> Int = handleIntLiteral(Negated);
if (!Int.has_value())
return std::nullopt;
return float(Int.value());
}

if (tryConsumeExpectedToken(TokenKind::float_literal)) {
std::optional<float> Float = handleFloatLiteral(Negated);
if (!Float.has_value())
return std::nullopt;
return Float.value();
}

return std::nullopt;
}

std::optional<llvm::hlsl::rootsig::ShaderVisibility>
RootSignatureParser::parseShaderVisibility() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
Expand Down Expand Up @@ -819,22 +873,121 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
PP.getSourceManager(), PP.getLangOpts(),
PP.getTargetInfo(), PP.getDiagnostics());
if (Literal.hadError)
return true; // Error has already been reported so just return
return std::nullopt; // Error has already been reported so just return

assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
assert(Literal.isIntegerLiteral() &&
"NumSpelling can only consist of digits");

llvm::APSInt Val = llvm::APSInt(32, false);
llvm::APSInt Val(32, /*IsUnsigned=*/true);
if (Literal.GetIntegerValue(Val)) {
// Report that the value has overflowed
PP.getDiagnostics().Report(CurToken.TokLoc,
diag::err_hlsl_number_literal_overflow)
<< 0 << CurToken.NumSpelling;
<< /*integer type*/ 0 << /*is signed*/ 0;
return std::nullopt;
}

return Val.getExtValue();
}

std::optional<int32_t> RootSignatureParser::handleIntLiteral(bool Negated) {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
PP.getSourceManager(), PP.getLangOpts(),
PP.getTargetInfo(), PP.getDiagnostics());
if (Literal.hadError)
return std::nullopt; // Error has already been reported so just return

assert(Literal.isIntegerLiteral() &&
"NumSpelling can only consist of digits");

llvm::APSInt Val(32, /*IsUnsigned=*/true);
// GetIntegerValue will overwrite Val from the parsed Literal and return
// true if it overflows as a 32-bit unsigned int
bool Overflowed = Literal.GetIntegerValue(Val);

// So we then need to check that it doesn't overflow as a 32-bit signed int:
int64_t MaxNegativeMagnitude = -int64_t(std::numeric_limits<int32_t>::min());
Overflowed |= (Negated && MaxNegativeMagnitude < Val.getExtValue());

int64_t MaxPositiveMagnitude = int64_t(std::numeric_limits<int32_t>::max());
Overflowed |= (!Negated && MaxPositiveMagnitude < Val.getExtValue());

if (Overflowed) {
// Report that the value has overflowed
PP.getDiagnostics().Report(CurToken.TokLoc,
diag::err_hlsl_number_literal_overflow)
<< /*integer type*/ 0 << /*is signed*/ 1;
return std::nullopt;
}

if (Negated)
Val = -Val;

return int32_t(Val.getExtValue());
}

std::optional<float> RootSignatureParser::handleFloatLiteral(bool Negated) {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
PP.getSourceManager(), PP.getLangOpts(),
PP.getTargetInfo(), PP.getDiagnostics());
if (Literal.hadError)
return std::nullopt; // Error has already been reported so just return

assert(Literal.isFloatingLiteral() &&
"NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
"will be caught and reported by NumericLiteralParser.");

// DXC used `strtod` to convert the token string to a float which corresponds
// to:
auto DXCSemantics = llvm::APFloat::Semantics::S_IEEEdouble;
auto DXCRoundingMode = llvm::RoundingMode::NearestTiesToEven;

llvm::APFloat Val(llvm::APFloat::EnumToSemantics(DXCSemantics));
llvm::APFloat::opStatus Status(Literal.GetFloatValue(Val, DXCRoundingMode));

// Note: we do not error when opStatus::opInexact by itself as this just
// denotes that rounding occured but not that it is invalid
assert(!(Status & llvm::APFloat::opStatus::opInvalidOp) &&
"NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
"will be caught and reported by NumericLiteralParser.");

assert(!(Status & llvm::APFloat::opStatus::opDivByZero) &&
"It is not possible for a division to be performed when "
"constructing an APFloat from a string");

if (Status & llvm::APFloat::opStatus::opUnderflow) {
// Report that the value has underflowed
PP.getDiagnostics().Report(CurToken.TokLoc,
diag::err_hlsl_number_literal_underflow);
return std::nullopt;
}

if (Status & llvm::APFloat::opStatus::opOverflow) {
// Report that the value has overflowed
PP.getDiagnostics().Report(CurToken.TokLoc,
diag::err_hlsl_number_literal_overflow)
<< /*float type*/ 1;
return std::nullopt;
}

if (Negated)
Val = -Val;

double DoubleVal = Val.convertToDouble();
double FloatMax = double(std::numeric_limits<float>::max());
if (FloatMax < DoubleVal || DoubleVal < -FloatMax) {
// Report that the value has overflowed
PP.getDiagnostics().Report(CurToken.TokLoc,
diag::err_hlsl_number_literal_overflow)
<< /*float type*/ 1;
return std::nullopt;
}

return static_cast<float>(DoubleVal);
}

bool RootSignatureParser::verifyZeroFlag() {
assert(CurToken.TokKind == TokenKind::int_literal);
auto X = handleUIntLiteral();
Expand Down
2 changes: 2 additions & 0 deletions clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
space visibility flags
numDescriptors offset

mipLODBias

unbounded
DESCRIPTOR_RANGE_OFFSET_APPEND

Expand Down
Loading
Loading