Skip to content

Commit a38f10b

Browse files
author
joaosaffran
committed
refactoring mcdxbc struct to store root parameters out of order
1 parent a928e9d commit a38f10b

File tree

4 files changed

+201
-75
lines changed

4 files changed

+201
-75
lines changed

llvm/include/llvm/MC/DXContainerRootSignature.h

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,146 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "llvm/ADT/STLForwardCompat.h"
910
#include "llvm/BinaryFormat/DXContainer.h"
11+
#include "llvm/Support/ErrorHandling.h"
12+
#include <cstddef>
1013
#include <cstdint>
11-
#include <limits>
14+
#include <variant>
1215

1316
namespace llvm {
1417

1518
class raw_ostream;
1619
namespace mcdxbc {
1720

21+
struct RootParameterHeader : public dxbc::RootParameterHeader {
22+
23+
size_t Location;
24+
25+
RootParameterHeader() = default;
26+
27+
RootParameterHeader(dxbc::RootParameterHeader H, size_t L)
28+
: dxbc::RootParameterHeader(H), Location(L) {}
29+
};
30+
31+
using RootDescriptor = std::variant<dxbc::RST0::v0::RootDescriptor,
32+
dxbc::RST0::v1::RootDescriptor>;
33+
using ParametersView =
34+
std::variant<dxbc::RootConstants, dxbc::RST0::v0::RootDescriptor,
35+
dxbc::RST0::v1::RootDescriptor>;
1836
struct RootParameter {
19-
dxbc::RootParameterHeader Header;
20-
union {
21-
dxbc::RootConstants Constants;
22-
dxbc::RST0::v1::RootDescriptor Descriptor;
37+
SmallVector<RootParameterHeader> Headers;
38+
39+
SmallVector<dxbc::RootConstants> Constants;
40+
SmallVector<RootDescriptor> Descriptors;
41+
42+
void addHeader(dxbc::RootParameterHeader H, size_t L) {
43+
Headers.push_back(RootParameterHeader(H, L));
44+
}
45+
46+
void addParameter(dxbc::RootParameterHeader H, dxbc::RootConstants C) {
47+
addHeader(H, Constants.size());
48+
Constants.push_back(C);
49+
}
50+
51+
void addParameter(dxbc::RootParameterHeader H,
52+
dxbc::RST0::v0::RootDescriptor D) {
53+
addHeader(H, Descriptors.size());
54+
Descriptors.push_back(D);
55+
}
56+
57+
void addParameter(dxbc::RootParameterHeader H,
58+
dxbc::RST0::v1::RootDescriptor D) {
59+
addHeader(H, Descriptors.size());
60+
Descriptors.push_back(D);
61+
}
62+
63+
ParametersView get(const RootParameterHeader &H) const {
64+
switch (H.ParameterType) {
65+
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
66+
return Constants[H.Location];
67+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
68+
case llvm::to_underlying(dxbc::RootParameterType::SRV):
69+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
70+
RootDescriptor VersionedParam = Descriptors[H.Location];
71+
if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(
72+
VersionedParam))
73+
return std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam);
74+
return std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam);
75+
}
76+
77+
llvm_unreachable("Unimplemented parameter type");
78+
}
79+
80+
struct iterator {
81+
const RootParameter &Parameters;
82+
SmallVector<RootParameterHeader>::const_iterator Current;
83+
84+
// Changed parameter type to match member variable (removed const)
85+
iterator(const RootParameter &P,
86+
SmallVector<RootParameterHeader>::const_iterator C)
87+
: Parameters(P), Current(C) {}
88+
iterator(const iterator &) = default;
89+
90+
ParametersView operator*() {
91+
ParametersView Val;
92+
switch (Current->ParameterType) {
93+
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
94+
Val = Parameters.Constants[Current->Location];
95+
break;
96+
97+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
98+
case llvm::to_underlying(dxbc::RootParameterType::SRV):
99+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
100+
RootDescriptor VersionedParam =
101+
Parameters.Descriptors[Current->Location];
102+
if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(
103+
VersionedParam))
104+
Val = std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam);
105+
else
106+
Val = std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam);
107+
break;
108+
}
109+
return Val;
110+
}
111+
112+
iterator operator++() {
113+
Current++;
114+
return *this;
115+
}
116+
117+
iterator operator++(int) {
118+
iterator Tmp = *this;
119+
++*this;
120+
return Tmp;
121+
}
122+
123+
iterator operator--() {
124+
Current--;
125+
return *this;
126+
}
127+
128+
iterator operator--(int) {
129+
iterator Tmp = *this;
130+
--*this;
131+
return Tmp;
132+
}
133+
134+
bool operator==(const iterator I) { return I.Current == Current; }
135+
bool operator!=(const iterator I) { return !(*this == I); }
23136
};
137+
138+
iterator begin() const { return iterator(*this, Headers.begin()); }
139+
140+
iterator end() const { return iterator(*this, Headers.end()); }
141+
142+
size_t size() const { return Headers.size(); }
143+
144+
bool isEmpty() const { return Headers.empty(); }
145+
146+
llvm::iterator_range<RootParameter::iterator> getAll() const {
147+
return llvm::make_range(begin(), end());
148+
}
24149
};
25150
struct RootSignatureDesc {
26151

@@ -29,7 +154,7 @@ struct RootSignatureDesc {
29154
uint32_t RootParameterOffset = 0U;
30155
uint32_t StaticSamplersOffset = 0u;
31156
uint32_t NumStaticSamplers = 0u;
32-
SmallVector<mcdxbc::RootParameter> Parameters;
157+
mcdxbc::RootParameter Parameters;
33158

34159
void write(raw_ostream &OS) const;
35160

llvm/lib/MC/DXContainerRootSignature.cpp

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
#include "llvm/MC/DXContainerRootSignature.h"
1010
#include "llvm/ADT/SmallString.h"
11+
#include "llvm/BinaryFormat/DXContainer.h"
1112
#include "llvm/Support/EndianStream.h"
13+
#include <variant>
1214

1315
using namespace llvm;
1416
using namespace llvm::mcdxbc;
@@ -32,22 +34,15 @@ size_t RootSignatureDesc::getSize() const {
3234
size_t Size = sizeof(dxbc::RootSignatureHeader) +
3335
Parameters.size() * sizeof(dxbc::RootParameterHeader);
3436

35-
for (const mcdxbc::RootParameter &P : Parameters) {
36-
switch (P.Header.ParameterType) {
37-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
38-
Size += sizeof(dxbc::RootConstants);
39-
break;
40-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
41-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
42-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
43-
if (Version == 1)
44-
Size += sizeof(dxbc::RST0::v0::RootDescriptor);
45-
else
46-
Size += sizeof(dxbc::RST0::v1::RootDescriptor);
47-
48-
break;
49-
}
37+
for (const auto &P : Parameters) {
38+
std::visit(
39+
[&Size](auto &Value) -> void {
40+
using T = std::decay_t<decltype(Value)>;
41+
Size += sizeof(T);
42+
},
43+
P);
5044
}
45+
5146
return Size;
5247
}
5348

@@ -66,39 +61,40 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
6661
support::endian::write(BOS, Flags, llvm::endianness::little);
6762

6863
SmallVector<uint32_t> ParamsOffsets;
69-
for (const mcdxbc::RootParameter &P : Parameters) {
70-
support::endian::write(BOS, P.Header.ParameterType,
71-
llvm::endianness::little);
72-
support::endian::write(BOS, P.Header.ShaderVisibility,
73-
llvm::endianness::little);
64+
for (const auto &P : Parameters.Headers) {
65+
support::endian::write(BOS, P.ParameterType, llvm::endianness::little);
66+
support::endian::write(BOS, P.ShaderVisibility, llvm::endianness::little);
7467

7568
ParamsOffsets.push_back(writePlaceholder(BOS));
7669
}
7770

7871
assert(NumParameters == ParamsOffsets.size());
79-
for (size_t I = 0; I < NumParameters; ++I) {
72+
auto P = Parameters.begin();
73+
for (size_t I = 0; I < NumParameters; ++I, P++) {
8074
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
81-
const mcdxbc::RootParameter &P = Parameters[I];
8275

83-
switch (P.Header.ParameterType) {
84-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
85-
support::endian::write(BOS, P.Constants.ShaderRegister,
76+
if (std::holds_alternative<dxbc::RootConstants>(*P)) {
77+
auto Constants = std::get<dxbc::RootConstants>(*P);
78+
support::endian::write(BOS, Constants.ShaderRegister,
8679
llvm::endianness::little);
87-
support::endian::write(BOS, P.Constants.RegisterSpace,
80+
support::endian::write(BOS, Constants.RegisterSpace,
8881
llvm::endianness::little);
89-
support::endian::write(BOS, P.Constants.Num32BitValues,
82+
support::endian::write(BOS, Constants.Num32BitValues,
9083
llvm::endianness::little);
91-
break;
92-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
93-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
94-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
95-
support::endian::write(BOS, P.Descriptor.ShaderRegister,
84+
} else if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(*P)) {
85+
auto Descriptor = std::get<dxbc::RST0::v0::RootDescriptor>(*P);
86+
support::endian::write(BOS, Descriptor.ShaderRegister,
87+
llvm::endianness::little);
88+
support::endian::write(BOS, Descriptor.RegisterSpace,
89+
llvm::endianness::little);
90+
} else if (std::holds_alternative<dxbc::RST0::v1::RootDescriptor>(*P)) {
91+
auto Descriptor = std::get<dxbc::RST0::v1::RootDescriptor>(*P);
92+
93+
support::endian::write(BOS, Descriptor.ShaderRegister,
9694
llvm::endianness::little);
97-
support::endian::write(BOS, P.Descriptor.RegisterSpace,
95+
support::endian::write(BOS, Descriptor.RegisterSpace,
9896
llvm::endianness::little);
99-
if (Version > 1)
100-
support::endian::write(BOS, P.Descriptor.Flags,
101-
llvm::endianness::little);
97+
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
10298
}
10399
}
104100
assert(Storage.size() == getSize());

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -274,27 +274,31 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
274274
RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
275275

276276
for (const auto &Param : P.RootSignature->Parameters) {
277-
mcdxbc::RootParameter NewParam;
278-
NewParam.Header = dxbc::RootParameterHeader{
279-
Param.Type, Param.Visibility, Param.Offset};
277+
auto Header = dxbc::RootParameterHeader{Param.Type, Param.Visibility,
278+
Param.Offset};
280279

281280
switch (Param.Type) {
282281
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
283-
NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues;
284-
NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace;
285-
NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister;
282+
dxbc::RootConstants Constants;
283+
Constants.Num32BitValues = Param.Constants.Num32BitValues;
284+
Constants.RegisterSpace = Param.Constants.RegisterSpace;
285+
Constants.ShaderRegister = Param.Constants.ShaderRegister;
286+
RS.Parameters.addParameter(Header, Constants);
286287
break;
287288
case llvm::to_underlying(dxbc::RootParameterType::SRV):
288289
case llvm::to_underlying(dxbc::RootParameterType::UAV):
289290
case llvm::to_underlying(dxbc::RootParameterType::CBV):
290-
NewParam.Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
291-
NewParam.Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
291+
dxbc::RST0::v1::RootDescriptor Descriptor;
292+
Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
293+
Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
292294
if (P.RootSignature->Version > 1)
293-
NewParam.Descriptor.Flags = Param.Descriptor.getEncodedFlags();
295+
Descriptor.Flags = Param.Descriptor.getEncodedFlags();
296+
RS.Parameters.addParameter(Header, Descriptor);
294297
break;
298+
default:
299+
// Handling invalid parameter type edge case
300+
RS.Parameters.addHeader(Header, -1);
295301
}
296-
297-
RS.Parameters.push_back(NewParam);
298302
}
299303

300304
RS.write(OS);

0 commit comments

Comments
 (0)