Skip to content

Commit b5d5708

Browse files
joaosaffranjoaosaffran
andauthored
[HLSL] Add descriptor table metadata parsing (#142492)
Implements descriptor table parsing from root signature metadata. This is required to support root signatures in hlsl. Closes: #[126640](#126640) --------- Co-authored-by: joaosaffran <[email protected]>
1 parent 8d2eea9 commit b5d5708

12 files changed

+559
-35
lines changed

llvm/include/llvm/BinaryFormat/DXContainer.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LLVM_BINARYFORMAT_DXCONTAINER_H
1414
#define LLVM_BINARYFORMAT_DXCONTAINER_H
1515

16+
#include "llvm/ADT/BitmaskEnum.h"
1617
#include "llvm/ADT/StringRef.h"
1718
#include "llvm/Support/Compiler.h"
1819
#include "llvm/Support/Error.h"
@@ -40,6 +41,8 @@ template <typename T> struct EnumEntry;
4041

4142
namespace dxbc {
4243

44+
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
45+
4346
inline Triple::EnvironmentType getShaderStage(uint32_t Kind) {
4447
assert(Kind <= Triple::Amplification - Triple::Pixel &&
4548
"Shader kind out of expected range.");
@@ -167,6 +170,8 @@ enum class RootDescriptorFlag : uint32_t {
167170
#define DESCRIPTOR_RANGE_FLAG(Num, Val) Val = Num,
168171
enum class DescriptorRangeFlag : uint32_t {
169172
#include "DXContainerConstants.def"
173+
174+
LLVM_MARK_AS_BITMASK_ENUM(DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS)
170175
};
171176

172177
#define ROOT_PARAMETER(Val, Enum) Enum = Val,

llvm/include/llvm/BinaryFormat/DXContainerConstants.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,16 @@ DESCRIPTOR_RANGE_FLAG(0x10000, DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS)
9999
#undef DESCRIPTOR_RANGE_FLAG
100100
#endif // DESCRIPTOR_RANGE_FLAG
101101

102+
// DESCRIPTOR_RANGE(value, name).
103+
#ifdef DESCRIPTOR_RANGE
104+
105+
DESCRIPTOR_RANGE(0, SRV)
106+
DESCRIPTOR_RANGE(1, UAV)
107+
DESCRIPTOR_RANGE(2, CBV)
108+
DESCRIPTOR_RANGE(3, Sampler)
109+
#undef DESCRIPTOR_RANGE
110+
#endif // DESCRIPTOR_RANGE
111+
102112
#ifdef ROOT_PARAMETER
103113

104114
ROOT_PARAMETER(0, DescriptorTable)

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 213 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,93 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
174174
return false;
175175
}
176176

177+
static bool parseDescriptorRange(LLVMContext *Ctx,
178+
mcdxbc::DescriptorTable &Table,
179+
MDNode *RangeDescriptorNode) {
180+
181+
if (RangeDescriptorNode->getNumOperands() != 6)
182+
return reportError(Ctx, "Invalid format for Descriptor Range");
183+
184+
dxbc::RTS0::v2::DescriptorRange Range;
185+
186+
std::optional<StringRef> ElementText =
187+
extractMdStringValue(RangeDescriptorNode, 0);
188+
189+
if (!ElementText.has_value())
190+
return reportError(Ctx, "Descriptor Range, first element is not a string.");
191+
192+
Range.RangeType =
193+
StringSwitch<uint32_t>(*ElementText)
194+
.Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
195+
.Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
196+
.Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
197+
.Case("Sampler",
198+
llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
199+
.Default(~0U);
200+
201+
if (Range.RangeType == ~0U)
202+
return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
203+
204+
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
205+
Range.NumDescriptors = *Val;
206+
else
207+
return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
208+
209+
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
210+
Range.BaseShaderRegister = *Val;
211+
else
212+
return reportError(Ctx, "Invalid value for BaseShaderRegister");
213+
214+
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
215+
Range.RegisterSpace = *Val;
216+
else
217+
return reportError(Ctx, "Invalid value for RegisterSpace");
218+
219+
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
220+
Range.OffsetInDescriptorsFromTableStart = *Val;
221+
else
222+
return reportError(Ctx,
223+
"Invalid value for OffsetInDescriptorsFromTableStart");
224+
225+
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
226+
Range.Flags = *Val;
227+
else
228+
return reportError(Ctx, "Invalid value for Descriptor Range Flags");
229+
230+
Table.Ranges.push_back(Range);
231+
return false;
232+
}
233+
234+
static bool parseDescriptorTable(LLVMContext *Ctx,
235+
mcdxbc::RootSignatureDesc &RSD,
236+
MDNode *DescriptorTableNode) {
237+
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
238+
if (NumOperands < 2)
239+
return reportError(Ctx, "Invalid format for Descriptor Table");
240+
241+
dxbc::RTS0::v1::RootParameterHeader Header;
242+
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
243+
Header.ShaderVisibility = *Val;
244+
else
245+
return reportError(Ctx, "Invalid value for ShaderVisibility");
246+
247+
mcdxbc::DescriptorTable Table;
248+
Header.ParameterType =
249+
llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
250+
251+
for (unsigned int I = 2; I < NumOperands; I++) {
252+
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
253+
if (Element == nullptr)
254+
return reportError(Ctx, "Missing Root Element Metadata Node.");
255+
256+
if (parseDescriptorRange(Ctx, Table, Element))
257+
return true;
258+
}
259+
260+
RSD.ParametersContainer.addParameter(Header, Table);
261+
return false;
262+
}
263+
177264
static bool parseRootSignatureElement(LLVMContext *Ctx,
178265
mcdxbc::RootSignatureDesc &RSD,
179266
MDNode *Element) {
@@ -188,6 +275,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
188275
.Case("RootCBV", RootSignatureElementKind::CBV)
189276
.Case("RootSRV", RootSignatureElementKind::SRV)
190277
.Case("RootUAV", RootSignatureElementKind::UAV)
278+
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
191279
.Default(RootSignatureElementKind::Error);
192280

193281
switch (ElementKind) {
@@ -200,6 +288,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
200288
case RootSignatureElementKind::SRV:
201289
case RootSignatureElementKind::UAV:
202290
return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
291+
case RootSignatureElementKind::DescriptorTable:
292+
return parseDescriptorTable(Ctx, RSD, Element);
203293
case RootSignatureElementKind::Error:
204294
return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
205295
}
@@ -241,6 +331,81 @@ static bool verifyRegisterSpace(uint32_t RegisterSpace) {
241331

242332
static bool verifyDescriptorFlag(uint32_t Flags) { return (Flags & ~0xE) == 0; }
243333

334+
static bool verifyRangeType(uint32_t Type) {
335+
switch (Type) {
336+
case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
337+
case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
338+
case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
339+
case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
340+
return true;
341+
};
342+
343+
return false;
344+
}
345+
346+
static bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
347+
uint32_t FlagsVal) {
348+
using FlagT = dxbc::DescriptorRangeFlag;
349+
FlagT Flags = FlagT(FlagsVal);
350+
351+
const bool IsSampler =
352+
(Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler));
353+
354+
if (Version == 1) {
355+
// Since the metadata is unversioned, we expect to explicitly see the values
356+
// that map to the version 1 behaviour here.
357+
if (IsSampler)
358+
return Flags == FlagT::DESCRIPTORS_VOLATILE;
359+
return Flags == (FlagT::DATA_VOLATILE | FlagT::DESCRIPTORS_VOLATILE);
360+
}
361+
362+
// The data-specific flags are mutually exclusive.
363+
FlagT DataFlags = FlagT::DATA_VOLATILE | FlagT::DATA_STATIC |
364+
FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
365+
366+
if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
367+
return false;
368+
369+
// The descriptor-specific flags are mutually exclusive.
370+
FlagT DescriptorFlags =
371+
FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS |
372+
FlagT::DESCRIPTORS_VOLATILE;
373+
if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1)
374+
return false;
375+
376+
// For volatile descriptors, DATA_STATIC is never valid.
377+
if ((Flags & FlagT::DESCRIPTORS_VOLATILE) == FlagT::DESCRIPTORS_VOLATILE) {
378+
FlagT Mask = FlagT::DESCRIPTORS_VOLATILE;
379+
if (!IsSampler) {
380+
Mask |= FlagT::DATA_VOLATILE;
381+
Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
382+
}
383+
return (Flags & ~Mask) == FlagT::NONE;
384+
}
385+
386+
// For "STATIC_KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
387+
// the other data-specific flags may all be set.
388+
if ((Flags & FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) ==
389+
FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) {
390+
FlagT Mask = FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS;
391+
if (!IsSampler) {
392+
Mask |= FlagT::DATA_VOLATILE;
393+
Mask |= FlagT::DATA_STATIC;
394+
Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
395+
}
396+
return (Flags & ~Mask) == FlagT::NONE;
397+
}
398+
399+
// When no descriptor flag is set, any data flag is allowed.
400+
FlagT Mask = FlagT::NONE;
401+
if (!IsSampler) {
402+
Mask |= FlagT::DATA_VOLATILE;
403+
Mask |= FlagT::DATA_STATIC;
404+
Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
405+
}
406+
return (Flags & ~Mask) == FlagT::NONE;
407+
}
408+
244409
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
245410

246411
if (!verifyVersion(RSD.Version)) {
@@ -275,7 +440,23 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
275440

276441
if (RSD.Version > 1) {
277442
if (!verifyDescriptorFlag(Descriptor.Flags))
278-
return reportValueError(Ctx, "DescriptorFlag", Descriptor.Flags);
443+
return reportValueError(Ctx, "DescriptorRangeFlag", Descriptor.Flags);
444+
}
445+
break;
446+
}
447+
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
448+
const mcdxbc::DescriptorTable &Table =
449+
RSD.ParametersContainer.getDescriptorTable(Info.Location);
450+
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
451+
if (!verifyRangeType(Range.RangeType))
452+
return reportValueError(Ctx, "RangeType", Range.RangeType);
453+
454+
if (!verifyRegisterSpace(Range.RegisterSpace))
455+
return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
456+
457+
if (!verifyDescriptorRangeFlag(RSD.Version, Range.RangeType,
458+
Range.Flags))
459+
return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
279460
}
280461
break;
281462
}
@@ -388,67 +569,67 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
388569

389570
OS << "Root Signature Definitions"
390571
<< "\n";
391-
uint8_t Space = 0;
392572
for (const Function &F : M) {
393573
auto It = RSDMap.find(&F);
394574
if (It == RSDMap.end())
395575
continue;
396576
const auto &RS = It->second;
397577
OS << "Definition for '" << F.getName() << "':\n";
398-
399578
// start root signature header
400-
Space++;
401-
OS << indent(Space) << "Flags: " << format_hex(RS.Flags, 8) << "\n";
402-
OS << indent(Space) << "Version: " << RS.Version << "\n";
403-
OS << indent(Space) << "RootParametersOffset: " << RS.RootParameterOffset
404-
<< "\n";
405-
OS << indent(Space) << "NumParameters: " << RS.ParametersContainer.size()
406-
<< "\n";
407-
Space++;
579+
OS << "Flags: " << format_hex(RS.Flags, 8) << "\n"
580+
<< "Version: " << RS.Version << "\n"
581+
<< "RootParametersOffset: " << RS.RootParameterOffset << "\n"
582+
<< "NumParameters: " << RS.ParametersContainer.size() << "\n";
408583
for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
409584
const auto &[Type, Loc] =
410585
RS.ParametersContainer.getTypeAndLocForParameter(I);
411586
const dxbc::RTS0::v1::RootParameterHeader Header =
412587
RS.ParametersContainer.getHeader(I);
413588

414-
OS << indent(Space) << "- Parameter Type: " << Type << "\n";
415-
OS << indent(Space + 2)
416-
<< "Shader Visibility: " << Header.ShaderVisibility << "\n";
589+
OS << "- Parameter Type: " << Type << "\n"
590+
<< " Shader Visibility: " << Header.ShaderVisibility << "\n";
417591

418592
switch (Type) {
419593
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
420594
const dxbc::RTS0::v1::RootConstants &Constants =
421595
RS.ParametersContainer.getConstant(Loc);
422-
OS << indent(Space + 2) << "Register Space: " << Constants.RegisterSpace
423-
<< "\n";
424-
OS << indent(Space + 2)
425-
<< "Shader Register: " << Constants.ShaderRegister << "\n";
426-
OS << indent(Space + 2)
427-
<< "Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
596+
OS << " Register Space: " << Constants.RegisterSpace << "\n"
597+
<< " Shader Register: " << Constants.ShaderRegister << "\n"
598+
<< " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
428599
break;
429600
}
430601
case llvm::to_underlying(dxbc::RootParameterType::CBV):
431602
case llvm::to_underlying(dxbc::RootParameterType::UAV):
432603
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
433604
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
434605
RS.ParametersContainer.getRootDescriptor(Loc);
435-
OS << indent(Space + 2)
436-
<< "Register Space: " << Descriptor.RegisterSpace << "\n";
437-
OS << indent(Space + 2)
438-
<< "Shader Register: " << Descriptor.ShaderRegister << "\n";
606+
OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
607+
<< " Shader Register: " << Descriptor.ShaderRegister << "\n";
439608
if (RS.Version > 1)
440-
OS << indent(Space + 2) << "Flags: " << Descriptor.Flags << "\n";
609+
OS << " Flags: " << Descriptor.Flags << "\n";
610+
break;
611+
}
612+
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
613+
const mcdxbc::DescriptorTable &Table =
614+
RS.ParametersContainer.getDescriptorTable(Loc);
615+
OS << " NumRanges: " << Table.Ranges.size() << "\n";
616+
617+
for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
618+
OS << " - Range Type: " << Range.RangeType << "\n"
619+
<< " Register Space: " << Range.RegisterSpace << "\n"
620+
<< " Base Shader Register: " << Range.BaseShaderRegister << "\n"
621+
<< " Num Descriptors: " << Range.NumDescriptors << "\n"
622+
<< " Offset In Descriptors From Table Start: "
623+
<< Range.OffsetInDescriptorsFromTableStart << "\n";
624+
if (RS.Version > 1)
625+
OS << " Flags: " << Range.Flags << "\n";
626+
}
441627
break;
442628
}
443629
}
444-
Space--;
445630
}
446-
OS << indent(Space) << "NumStaticSamplers: " << 0 << "\n";
447-
OS << indent(Space) << "StaticSamplersOffset: " << RS.StaticSamplersOffset
448-
<< "\n";
449-
450-
Space--;
451-
// end root signature header
631+
OS << "NumStaticSamplers: " << 0 << "\n";
632+
OS << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n";
452633
}
453634
return PreservedAnalyses::all();
454635
}

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ enum class RootSignatureElementKind {
3131
SRV = 3,
3232
UAV = 4,
3333
CBV = 5,
34+
DescriptorTable = 6,
3435
};
3536
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
3637
friend AnalysisInfoMixin<RootSignatureAnalysis>;

0 commit comments

Comments
 (0)