Skip to content

Commit 63b80dd

Browse files
authored
[NFC][RootSignature] Use llvm::EnumEntry for serialization of Root Signature Elements (#144106)
It has pointed out [here](#143198 (comment)) that we may be able to use `llvm::EnumEntry` so that we can re-use the printing logic across enumerations. - Enables re-use of `printEnum` and `printFlags` methods via templates - Allows easy definition of `getEnumName` function for enum-to-string conversion, eliminating the need to use a string stream for constructing the Name SmallString - Also, does a small fix-up of the operands for descriptor table clause to be consistent with other `Build*` methods For reference, the [test-cases](https://github.com/llvm/llvm-project/blob/main/llvm/unittests/Frontend/HLSLRootSignatureDumpTest.cpp) that must not change expected output.
1 parent 8ed43c4 commit 63b80dd

File tree

1 file changed

+104
-105
lines changed

1 file changed

+104
-105
lines changed

llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp

Lines changed: 104 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -15,111 +15,46 @@
1515
#include "llvm/ADT/bit.h"
1616
#include "llvm/IR/IRBuilder.h"
1717
#include "llvm/IR/Metadata.h"
18+
#include "llvm/Support/ScopedPrinter.h"
1819

1920
namespace llvm {
2021
namespace hlsl {
2122
namespace rootsig {
2223

23-
static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
24-
switch (Reg.ViewType) {
25-
case RegisterType::BReg:
26-
OS << "b";
27-
break;
28-
case RegisterType::TReg:
29-
OS << "t";
30-
break;
31-
case RegisterType::UReg:
32-
OS << "u";
33-
break;
34-
case RegisterType::SReg:
35-
OS << "s";
36-
break;
37-
}
38-
OS << Reg.Number;
39-
return OS;
24+
template <typename T>
25+
static std::optional<StringRef> getEnumName(const T Value,
26+
ArrayRef<EnumEntry<T>> Enums) {
27+
for (const auto &EnumItem : Enums)
28+
if (EnumItem.Value == Value)
29+
return EnumItem.Name;
30+
return std::nullopt;
4031
}
4132

42-
static raw_ostream &operator<<(raw_ostream &OS,
43-
const ShaderVisibility &Visibility) {
44-
switch (Visibility) {
45-
case ShaderVisibility::All:
46-
OS << "All";
47-
break;
48-
case ShaderVisibility::Vertex:
49-
OS << "Vertex";
50-
break;
51-
case ShaderVisibility::Hull:
52-
OS << "Hull";
53-
break;
54-
case ShaderVisibility::Domain:
55-
OS << "Domain";
56-
break;
57-
case ShaderVisibility::Geometry:
58-
OS << "Geometry";
59-
break;
60-
case ShaderVisibility::Pixel:
61-
OS << "Pixel";
62-
break;
63-
case ShaderVisibility::Amplification:
64-
OS << "Amplification";
65-
break;
66-
case ShaderVisibility::Mesh:
67-
OS << "Mesh";
68-
break;
69-
}
70-
71-
return OS;
72-
}
73-
74-
static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
75-
switch (Type) {
76-
case ClauseType::CBuffer:
77-
OS << "CBV";
78-
break;
79-
case ClauseType::SRV:
80-
OS << "SRV";
81-
break;
82-
case ClauseType::UAV:
83-
OS << "UAV";
84-
break;
85-
case ClauseType::Sampler:
86-
OS << "Sampler";
87-
break;
88-
}
89-
33+
template <typename T>
34+
static raw_ostream &printEnum(raw_ostream &OS, const T Value,
35+
ArrayRef<EnumEntry<T>> Enums) {
36+
auto MaybeName = getEnumName(Value, Enums);
37+
if (MaybeName)
38+
OS << *MaybeName;
9039
return OS;
9140
}
9241

93-
static raw_ostream &operator<<(raw_ostream &OS,
94-
const DescriptorRangeFlags &Flags) {
42+
template <typename T>
43+
static raw_ostream &printFlags(raw_ostream &OS, const T Value,
44+
ArrayRef<EnumEntry<T>> Flags) {
9545
bool FlagSet = false;
96-
unsigned Remaining = llvm::to_underlying(Flags);
46+
unsigned Remaining = llvm::to_underlying(Value);
9747
while (Remaining) {
9848
unsigned Bit = 1u << llvm::countr_zero(Remaining);
9949
if (Remaining & Bit) {
10050
if (FlagSet)
10151
OS << " | ";
10252

103-
switch (static_cast<DescriptorRangeFlags>(Bit)) {
104-
case DescriptorRangeFlags::DescriptorsVolatile:
105-
OS << "DescriptorsVolatile";
106-
break;
107-
case DescriptorRangeFlags::DataVolatile:
108-
OS << "DataVolatile";
109-
break;
110-
case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
111-
OS << "DataStaticWhileSetAtExecute";
112-
break;
113-
case DescriptorRangeFlags::DataStatic:
114-
OS << "DataStatic";
115-
break;
116-
case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
117-
OS << "DescriptorsStaticKeepingBufferBoundsChecks";
118-
break;
119-
default:
53+
auto MaybeFlag = getEnumName(T(Bit), Flags);
54+
if (MaybeFlag)
55+
OS << *MaybeFlag;
56+
else
12057
OS << "invalid: " << Bit;
121-
break;
122-
}
12358

12459
FlagSet = true;
12560
}
@@ -128,6 +63,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
12863

12964
if (!FlagSet)
13065
OS << "None";
66+
return OS;
67+
}
68+
69+
static const EnumEntry<RegisterType> RegisterNames[] = {
70+
{"b", RegisterType::BReg},
71+
{"t", RegisterType::TReg},
72+
{"u", RegisterType::UReg},
73+
{"s", RegisterType::SReg},
74+
};
75+
76+
static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
77+
printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
78+
OS << Reg.Number;
79+
80+
return OS;
81+
}
82+
83+
static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
84+
{"All", ShaderVisibility::All},
85+
{"Vertex", ShaderVisibility::Vertex},
86+
{"Hull", ShaderVisibility::Hull},
87+
{"Domain", ShaderVisibility::Domain},
88+
{"Geometry", ShaderVisibility::Geometry},
89+
{"Pixel", ShaderVisibility::Pixel},
90+
{"Amplification", ShaderVisibility::Amplification},
91+
{"Mesh", ShaderVisibility::Mesh},
92+
};
93+
94+
static raw_ostream &operator<<(raw_ostream &OS,
95+
const ShaderVisibility &Visibility) {
96+
printEnum(OS, Visibility, ArrayRef(VisibilityNames));
97+
98+
return OS;
99+
}
100+
101+
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
102+
{"CBV", dxil::ResourceClass::CBuffer},
103+
{"SRV", dxil::ResourceClass::SRV},
104+
{"UAV", dxil::ResourceClass::UAV},
105+
{"Sampler", dxil::ResourceClass::Sampler},
106+
};
107+
108+
static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
109+
printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
110+
ArrayRef(ResourceClassNames));
111+
112+
return OS;
113+
}
114+
115+
static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
116+
{"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
117+
{"DataVolatile", DescriptorRangeFlags::DataVolatile},
118+
{"DataStaticWhileSetAtExecute",
119+
DescriptorRangeFlags::DataStaticWhileSetAtExecute},
120+
{"DataStatic", DescriptorRangeFlags::DataStatic},
121+
{"DescriptorsStaticKeepingBufferBoundsChecks",
122+
DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
123+
};
124+
125+
static raw_ostream &operator<<(raw_ostream &OS,
126+
const DescriptorRangeFlags &Flags) {
127+
printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
131128

132129
return OS;
133130
}
@@ -236,12 +233,13 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
236233

237234
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
238235
IRBuilder<> Builder(Ctx);
239-
llvm::SmallString<7> Name;
240-
llvm::raw_svector_ostream OS(Name);
241-
OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
242-
236+
std::optional<StringRef> TypeName =
237+
getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
238+
ArrayRef(ResourceClassNames));
239+
assert(TypeName && "Provided an invalid Resource Class");
240+
llvm::SmallString<7> Name({"Root", *TypeName});
243241
Metadata *Operands[] = {
244-
MDString::get(Ctx, OS.str()),
242+
MDString::get(Ctx, Name),
245243
ConstantAsMetadata::get(
246244
Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
247245
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -277,19 +275,20 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
277275
MDNode *MetadataBuilder::BuildDescriptorTableClause(
278276
const DescriptorTableClause &Clause) {
279277
IRBuilder<> Builder(Ctx);
280-
std::string Name;
281-
llvm::raw_string_ostream OS(Name);
282-
OS << Clause.Type;
283-
return MDNode::get(
284-
Ctx, {
285-
MDString::get(Ctx, OS.str()),
286-
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
287-
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
288-
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
289-
ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
290-
ConstantAsMetadata::get(
291-
Builder.getInt32(llvm::to_underlying(Clause.Flags))),
292-
});
278+
std::optional<StringRef> Name =
279+
getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
280+
ArrayRef(ResourceClassNames));
281+
assert(Name && "Provided an invalid Resource Class");
282+
Metadata *Operands[] = {
283+
MDString::get(Ctx, *Name),
284+
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
285+
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
286+
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
287+
ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
288+
ConstantAsMetadata::get(
289+
Builder.getInt32(llvm::to_underlying(Clause.Flags))),
290+
};
291+
return MDNode::get(Ctx, Operands);
293292
}
294293

295294
MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {

0 commit comments

Comments
 (0)