15
15
#include " llvm/ADT/bit.h"
16
16
#include " llvm/IR/IRBuilder.h"
17
17
#include " llvm/IR/Metadata.h"
18
+ #include " llvm/Support/ScopedPrinter.h"
18
19
19
20
namespace llvm {
20
21
namespace hlsl {
21
22
namespace rootsig {
22
23
23
- static raw_ostream &operator <<(raw_ostream &OS, const Register &Reg) {
24
- switch (Reg.ViewType ) {
25
- case RegisterType::BReg:
26
- OS << " b" ;
27
- break ;
28
- case RegisterType::TReg:
29
- OS << " t" ;
30
- break ;
31
- case RegisterType::UReg:
32
- OS << " u" ;
33
- break ;
34
- case RegisterType::SReg:
35
- OS << " s" ;
36
- break ;
37
- }
38
- OS << Reg.Number ;
39
- return OS;
24
+ template <typename T>
25
+ static std::optional<StringRef> getEnumName (const T Value,
26
+ ArrayRef<EnumEntry<T>> Enums) {
27
+ for (const auto &EnumItem : Enums)
28
+ if (EnumItem.Value == Value)
29
+ return EnumItem.Name ;
30
+ return std::nullopt;
40
31
}
41
32
42
- static raw_ostream &operator <<(raw_ostream &OS,
43
- const ShaderVisibility &Visibility) {
44
- switch (Visibility) {
45
- case ShaderVisibility::All:
46
- OS << " All" ;
47
- break ;
48
- case ShaderVisibility::Vertex:
49
- OS << " Vertex" ;
50
- break ;
51
- case ShaderVisibility::Hull:
52
- OS << " Hull" ;
53
- break ;
54
- case ShaderVisibility::Domain:
55
- OS << " Domain" ;
56
- break ;
57
- case ShaderVisibility::Geometry:
58
- OS << " Geometry" ;
59
- break ;
60
- case ShaderVisibility::Pixel:
61
- OS << " Pixel" ;
62
- break ;
63
- case ShaderVisibility::Amplification:
64
- OS << " Amplification" ;
65
- break ;
66
- case ShaderVisibility::Mesh:
67
- OS << " Mesh" ;
68
- break ;
69
- }
70
-
71
- return OS;
72
- }
73
-
74
- static raw_ostream &operator <<(raw_ostream &OS, const ClauseType &Type) {
75
- switch (Type) {
76
- case ClauseType::CBuffer:
77
- OS << " CBV" ;
78
- break ;
79
- case ClauseType::SRV:
80
- OS << " SRV" ;
81
- break ;
82
- case ClauseType::UAV:
83
- OS << " UAV" ;
84
- break ;
85
- case ClauseType::Sampler:
86
- OS << " Sampler" ;
87
- break ;
88
- }
89
-
33
+ template <typename T>
34
+ static raw_ostream &printEnum (raw_ostream &OS, const T Value,
35
+ ArrayRef<EnumEntry<T>> Enums) {
36
+ auto MaybeName = getEnumName (Value, Enums);
37
+ if (MaybeName)
38
+ OS << *MaybeName;
90
39
return OS;
91
40
}
92
41
93
- static raw_ostream &operator <<(raw_ostream &OS,
94
- const DescriptorRangeFlags &Flags) {
42
+ template <typename T>
43
+ static raw_ostream &printFlags (raw_ostream &OS, const T Value,
44
+ ArrayRef<EnumEntry<T>> Flags) {
95
45
bool FlagSet = false ;
96
- unsigned Remaining = llvm::to_underlying (Flags );
46
+ unsigned Remaining = llvm::to_underlying (Value );
97
47
while (Remaining) {
98
48
unsigned Bit = 1u << llvm::countr_zero (Remaining);
99
49
if (Remaining & Bit) {
100
50
if (FlagSet)
101
51
OS << " | " ;
102
52
103
- switch (static_cast <DescriptorRangeFlags>(Bit)) {
104
- case DescriptorRangeFlags::DescriptorsVolatile:
105
- OS << " DescriptorsVolatile" ;
106
- break ;
107
- case DescriptorRangeFlags::DataVolatile:
108
- OS << " DataVolatile" ;
109
- break ;
110
- case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
111
- OS << " DataStaticWhileSetAtExecute" ;
112
- break ;
113
- case DescriptorRangeFlags::DataStatic:
114
- OS << " DataStatic" ;
115
- break ;
116
- case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
117
- OS << " DescriptorsStaticKeepingBufferBoundsChecks" ;
118
- break ;
119
- default :
53
+ auto MaybeFlag = getEnumName (T (Bit), Flags);
54
+ if (MaybeFlag)
55
+ OS << *MaybeFlag;
56
+ else
120
57
OS << " invalid: " << Bit;
121
- break ;
122
- }
123
58
124
59
FlagSet = true ;
125
60
}
@@ -128,6 +63,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
128
63
129
64
if (!FlagSet)
130
65
OS << " None" ;
66
+ return OS;
67
+ }
68
+
69
+ static const EnumEntry<RegisterType> RegisterNames[] = {
70
+ {" b" , RegisterType::BReg},
71
+ {" t" , RegisterType::TReg},
72
+ {" u" , RegisterType::UReg},
73
+ {" s" , RegisterType::SReg},
74
+ };
75
+
76
+ static raw_ostream &operator <<(raw_ostream &OS, const Register &Reg) {
77
+ printEnum (OS, Reg.ViewType , ArrayRef (RegisterNames));
78
+ OS << Reg.Number ;
79
+
80
+ return OS;
81
+ }
82
+
83
+ static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
84
+ {" All" , ShaderVisibility::All},
85
+ {" Vertex" , ShaderVisibility::Vertex},
86
+ {" Hull" , ShaderVisibility::Hull},
87
+ {" Domain" , ShaderVisibility::Domain},
88
+ {" Geometry" , ShaderVisibility::Geometry},
89
+ {" Pixel" , ShaderVisibility::Pixel},
90
+ {" Amplification" , ShaderVisibility::Amplification},
91
+ {" Mesh" , ShaderVisibility::Mesh},
92
+ };
93
+
94
+ static raw_ostream &operator <<(raw_ostream &OS,
95
+ const ShaderVisibility &Visibility) {
96
+ printEnum (OS, Visibility, ArrayRef (VisibilityNames));
97
+
98
+ return OS;
99
+ }
100
+
101
+ static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
102
+ {" CBV" , dxil::ResourceClass::CBuffer},
103
+ {" SRV" , dxil::ResourceClass::SRV},
104
+ {" UAV" , dxil::ResourceClass::UAV},
105
+ {" Sampler" , dxil::ResourceClass::Sampler},
106
+ };
107
+
108
+ static raw_ostream &operator <<(raw_ostream &OS, const ClauseType &Type) {
109
+ printEnum (OS, dxil::ResourceClass (llvm::to_underlying (Type)),
110
+ ArrayRef (ResourceClassNames));
111
+
112
+ return OS;
113
+ }
114
+
115
+ static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
116
+ {" DescriptorsVolatile" , DescriptorRangeFlags::DescriptorsVolatile},
117
+ {" DataVolatile" , DescriptorRangeFlags::DataVolatile},
118
+ {" DataStaticWhileSetAtExecute" ,
119
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute},
120
+ {" DataStatic" , DescriptorRangeFlags::DataStatic},
121
+ {" DescriptorsStaticKeepingBufferBoundsChecks" ,
122
+ DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
123
+ };
124
+
125
+ static raw_ostream &operator <<(raw_ostream &OS,
126
+ const DescriptorRangeFlags &Flags) {
127
+ printFlags (OS, Flags, ArrayRef (DescriptorRangeFlagNames));
131
128
132
129
return OS;
133
130
}
@@ -236,12 +233,13 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
236
233
237
234
MDNode *MetadataBuilder::BuildRootDescriptor (const RootDescriptor &Descriptor) {
238
235
IRBuilder<> Builder (Ctx);
239
- llvm::SmallString<7 > Name;
240
- llvm::raw_svector_ostream OS (Name);
241
- OS << " Root" << ClauseType (llvm::to_underlying (Descriptor.Type ));
242
-
236
+ std::optional<StringRef> TypeName =
237
+ getEnumName (dxil::ResourceClass (llvm::to_underlying (Descriptor.Type )),
238
+ ArrayRef (ResourceClassNames));
239
+ assert (TypeName && " Provided an invalid Resource Class" );
240
+ llvm::SmallString<7 > Name ({" Root" , *TypeName});
243
241
Metadata *Operands[] = {
244
- MDString::get (Ctx, OS. str () ),
242
+ MDString::get (Ctx, Name ),
245
243
ConstantAsMetadata::get (
246
244
Builder.getInt32 (llvm::to_underlying (Descriptor.Visibility ))),
247
245
ConstantAsMetadata::get (Builder.getInt32 (Descriptor.Reg .Number )),
@@ -277,19 +275,20 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
277
275
MDNode *MetadataBuilder::BuildDescriptorTableClause (
278
276
const DescriptorTableClause &Clause) {
279
277
IRBuilder<> Builder (Ctx);
280
- std::string Name;
281
- llvm::raw_string_ostream OS (Name);
282
- OS << Clause.Type ;
283
- return MDNode::get (
284
- Ctx, {
285
- MDString::get (Ctx, OS.str ()),
286
- ConstantAsMetadata::get (Builder.getInt32 (Clause.NumDescriptors )),
287
- ConstantAsMetadata::get (Builder.getInt32 (Clause.Reg .Number )),
288
- ConstantAsMetadata::get (Builder.getInt32 (Clause.Space )),
289
- ConstantAsMetadata::get (Builder.getInt32 (Clause.Offset )),
290
- ConstantAsMetadata::get (
291
- Builder.getInt32 (llvm::to_underlying (Clause.Flags ))),
292
- });
278
+ std::optional<StringRef> Name =
279
+ getEnumName (dxil::ResourceClass (llvm::to_underlying (Clause.Type )),
280
+ ArrayRef (ResourceClassNames));
281
+ assert (Name && " Provided an invalid Resource Class" );
282
+ Metadata *Operands[] = {
283
+ MDString::get (Ctx, *Name),
284
+ ConstantAsMetadata::get (Builder.getInt32 (Clause.NumDescriptors )),
285
+ ConstantAsMetadata::get (Builder.getInt32 (Clause.Reg .Number )),
286
+ ConstantAsMetadata::get (Builder.getInt32 (Clause.Space )),
287
+ ConstantAsMetadata::get (Builder.getInt32 (Clause.Offset )),
288
+ ConstantAsMetadata::get (
289
+ Builder.getInt32 (llvm::to_underlying (Clause.Flags ))),
290
+ };
291
+ return MDNode::get (Ctx, Operands);
293
292
}
294
293
295
294
MDNode *MetadataBuilder::BuildStaticSampler (const StaticSampler &Sampler) {
0 commit comments