Skip to content

Commit 7219ed4

Browse files
author
joaosaffran
committed
refactoring mcdxbc struct to store root parameters out of order
1 parent 8b8c02a commit 7219ed4

File tree

5 files changed

+208
-91
lines changed

5 files changed

+208
-91
lines changed

llvm/include/llvm/MC/DXContainerRootSignature.h

Lines changed: 131 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,146 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "llvm/ADT/STLForwardCompat.h"
910
#include "llvm/BinaryFormat/DXContainer.h"
11+
#include "llvm/Support/ErrorHandling.h"
12+
#include <cstddef>
1013
#include <cstdint>
11-
#include <limits>
14+
#include <variant>
1215

1316
namespace llvm {
1417

1518
class raw_ostream;
1619
namespace mcdxbc {
1720

21+
struct RootParameterHeader : public dxbc::RootParameterHeader {
22+
23+
size_t Location;
24+
25+
RootParameterHeader() = default;
26+
27+
RootParameterHeader(dxbc::RootParameterHeader H, size_t L)
28+
: dxbc::RootParameterHeader(H), Location(L) {}
29+
};
30+
31+
using RootDescriptor = std::variant<dxbc::RST0::v0::RootDescriptor,
32+
dxbc::RST0::v1::RootDescriptor>;
33+
using ParametersView =
34+
std::variant<dxbc::RootConstants, dxbc::RST0::v0::RootDescriptor,
35+
dxbc::RST0::v1::RootDescriptor>;
1836
struct RootParameter {
19-
dxbc::RootParameterHeader Header;
20-
union {
21-
dxbc::RootConstants Constants;
22-
dxbc::RST0::v0::RootDescriptor Descriptor_V10;
23-
dxbc::RST0::v1::RootDescriptor Descriptor_V11;
37+
SmallVector<RootParameterHeader> Headers;
38+
39+
SmallVector<dxbc::RootConstants> Constants;
40+
SmallVector<RootDescriptor> Descriptors;
41+
42+
void addHeader(dxbc::RootParameterHeader H, size_t L) {
43+
Headers.push_back(RootParameterHeader(H, L));
44+
}
45+
46+
void addParameter(dxbc::RootParameterHeader H, dxbc::RootConstants C) {
47+
addHeader(H, Constants.size());
48+
Constants.push_back(C);
49+
}
50+
51+
void addParameter(dxbc::RootParameterHeader H,
52+
dxbc::RST0::v0::RootDescriptor D) {
53+
addHeader(H, Descriptors.size());
54+
Descriptors.push_back(D);
55+
}
56+
57+
void addParameter(dxbc::RootParameterHeader H,
58+
dxbc::RST0::v1::RootDescriptor D) {
59+
addHeader(H, Descriptors.size());
60+
Descriptors.push_back(D);
61+
}
62+
63+
ParametersView get(const RootParameterHeader &H) const {
64+
switch (H.ParameterType) {
65+
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
66+
return Constants[H.Location];
67+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
68+
case llvm::to_underlying(dxbc::RootParameterType::SRV):
69+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
70+
RootDescriptor VersionedParam = Descriptors[H.Location];
71+
if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(
72+
VersionedParam))
73+
return std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam);
74+
return std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam);
75+
}
76+
77+
llvm_unreachable("Unimplemented parameter type");
78+
}
79+
80+
struct iterator {
81+
const RootParameter &Parameters;
82+
SmallVector<RootParameterHeader>::const_iterator Current;
83+
84+
// Changed parameter type to match member variable (removed const)
85+
iterator(const RootParameter &P,
86+
SmallVector<RootParameterHeader>::const_iterator C)
87+
: Parameters(P), Current(C) {}
88+
iterator(const iterator &) = default;
89+
90+
ParametersView operator*() {
91+
ParametersView Val;
92+
switch (Current->ParameterType) {
93+
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
94+
Val = Parameters.Constants[Current->Location];
95+
break;
96+
97+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
98+
case llvm::to_underlying(dxbc::RootParameterType::SRV):
99+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
100+
RootDescriptor VersionedParam =
101+
Parameters.Descriptors[Current->Location];
102+
if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(
103+
VersionedParam))
104+
Val = std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam);
105+
else
106+
Val = std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam);
107+
break;
108+
}
109+
return Val;
110+
}
111+
112+
iterator operator++() {
113+
Current++;
114+
return *this;
115+
}
116+
117+
iterator operator++(int) {
118+
iterator Tmp = *this;
119+
++*this;
120+
return Tmp;
121+
}
122+
123+
iterator operator--() {
124+
Current--;
125+
return *this;
126+
}
127+
128+
iterator operator--(int) {
129+
iterator Tmp = *this;
130+
--*this;
131+
return Tmp;
132+
}
133+
134+
bool operator==(const iterator I) { return I.Current == Current; }
135+
bool operator!=(const iterator I) { return !(*this == I); }
24136
};
137+
138+
iterator begin() const { return iterator(*this, Headers.begin()); }
139+
140+
iterator end() const { return iterator(*this, Headers.end()); }
141+
142+
size_t size() const { return Headers.size(); }
143+
144+
bool isEmpty() const { return Headers.empty(); }
145+
146+
llvm::iterator_range<RootParameter::iterator> getAll() const {
147+
return llvm::make_range(begin(), end());
148+
}
25149
};
26150
struct RootSignatureDesc {
27151

@@ -30,7 +154,7 @@ struct RootSignatureDesc {
30154
uint32_t RootParameterOffset = 0U;
31155
uint32_t StaticSamplersOffset = 0u;
32156
uint32_t NumStaticSamplers = 0u;
33-
SmallVector<mcdxbc::RootParameter> Parameters;
157+
mcdxbc::RootParameter Parameters;
34158

35159
void write(raw_ostream &OS) const;
36160

llvm/include/llvm/ObjectYAML/DXContainerYAML.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ struct RootParameterYamlDesc {
9595
uint32_t Type;
9696
uint32_t Visibility;
9797
uint32_t Offset;
98-
RootParameterYamlDesc() {};
98+
RootParameterYamlDesc(){};
9999
RootParameterYamlDesc(uint32_t T) : Type(T) {
100100
switch (T) {
101101

llvm/lib/MC/DXContainerRootSignature.cpp

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
#include "llvm/MC/DXContainerRootSignature.h"
1010
#include "llvm/ADT/SmallString.h"
11+
#include "llvm/BinaryFormat/DXContainer.h"
1112
#include "llvm/Support/EndianStream.h"
13+
#include <variant>
1214

1315
using namespace llvm;
1416
using namespace llvm::mcdxbc;
@@ -32,22 +34,15 @@ size_t RootSignatureDesc::getSize() const {
3234
size_t Size = sizeof(dxbc::RootSignatureHeader) +
3335
Parameters.size() * sizeof(dxbc::RootParameterHeader);
3436

35-
for (const mcdxbc::RootParameter &P : Parameters) {
36-
switch (P.Header.ParameterType) {
37-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
38-
Size += sizeof(dxbc::RootConstants);
39-
break;
40-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
41-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
42-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
43-
if (Version == 1)
44-
Size += sizeof(dxbc::RST0::v0::RootDescriptor);
45-
else
46-
Size += sizeof(dxbc::RST0::v1::RootDescriptor);
47-
48-
break;
49-
}
37+
for (const auto &P : Parameters) {
38+
std::visit(
39+
[&Size](auto &Value) -> void {
40+
using T = std::decay_t<decltype(Value)>;
41+
Size += sizeof(T);
42+
},
43+
P);
5044
}
45+
5146
return Size;
5247
}
5348

@@ -66,45 +61,40 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
6661
support::endian::write(BOS, Flags, llvm::endianness::little);
6762

6863
SmallVector<uint32_t> ParamsOffsets;
69-
for (const mcdxbc::RootParameter &P : Parameters) {
70-
support::endian::write(BOS, P.Header.ParameterType,
71-
llvm::endianness::little);
72-
support::endian::write(BOS, P.Header.ShaderVisibility,
73-
llvm::endianness::little);
64+
for (const auto &P : Parameters.Headers) {
65+
support::endian::write(BOS, P.ParameterType, llvm::endianness::little);
66+
support::endian::write(BOS, P.ShaderVisibility, llvm::endianness::little);
7467

7568
ParamsOffsets.push_back(writePlaceholder(BOS));
7669
}
7770

7871
assert(NumParameters == ParamsOffsets.size());
79-
for (size_t I = 0; I < NumParameters; ++I) {
72+
auto P = Parameters.begin();
73+
for (size_t I = 0; I < NumParameters; ++I, P++) {
8074
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
81-
const mcdxbc::RootParameter &P = Parameters[I];
8275

83-
switch (P.Header.ParameterType) {
84-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
85-
support::endian::write(BOS, P.Constants.ShaderRegister,
76+
if (std::holds_alternative<dxbc::RootConstants>(*P)) {
77+
auto Constants = std::get<dxbc::RootConstants>(*P);
78+
support::endian::write(BOS, Constants.ShaderRegister,
79+
llvm::endianness::little);
80+
support::endian::write(BOS, Constants.RegisterSpace,
8681
llvm::endianness::little);
87-
support::endian::write(BOS, P.Constants.RegisterSpace,
82+
support::endian::write(BOS, Constants.Num32BitValues,
83+
llvm::endianness::little);
84+
} else if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>(*P)) {
85+
auto Descriptor = std::get<dxbc::RST0::v0::RootDescriptor>(*P);
86+
support::endian::write(BOS, Descriptor.ShaderRegister,
87+
llvm::endianness::little);
88+
support::endian::write(BOS, Descriptor.RegisterSpace,
89+
llvm::endianness::little);
90+
} else if (std::holds_alternative<dxbc::RST0::v1::RootDescriptor>(*P)) {
91+
auto Descriptor = std::get<dxbc::RST0::v1::RootDescriptor>(*P);
92+
93+
support::endian::write(BOS, Descriptor.ShaderRegister,
8894
llvm::endianness::little);
89-
support::endian::write(BOS, P.Constants.Num32BitValues,
95+
support::endian::write(BOS, Descriptor.RegisterSpace,
9096
llvm::endianness::little);
91-
break;
92-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
93-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
94-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
95-
if (Version == 1) {
96-
support::endian::write(BOS, P.Descriptor_V10.ShaderRegister,
97-
llvm::endianness::little);
98-
support::endian::write(BOS, P.Descriptor_V10.RegisterSpace,
99-
llvm::endianness::little);
100-
} else {
101-
support::endian::write(BOS, P.Descriptor_V11.ShaderRegister,
102-
llvm::endianness::little);
103-
support::endian::write(BOS, P.Descriptor_V11.RegisterSpace,
104-
llvm::endianness::little);
105-
support::endian::write(BOS, P.Descriptor_V11.Flags,
106-
llvm::endianness::little);
107-
}
97+
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
10898
}
10999
}
110100
assert(Storage.size() == getSize());

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -274,36 +274,38 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
274274
RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
275275

276276
for (const auto &Param : P.RootSignature->Parameters) {
277-
mcdxbc::RootParameter NewParam;
278-
NewParam.Header = dxbc::RootParameterHeader{
279-
Param.Type, Param.Visibility, Param.Offset};
277+
auto Header = dxbc::RootParameterHeader{Param.Type, Param.Visibility,
278+
Param.Offset};
280279

281280
switch (Param.Type) {
282281
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
283-
NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues;
284-
NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace;
285-
NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister;
282+
dxbc::RootConstants Constants;
283+
Constants.Num32BitValues = Param.Constants.Num32BitValues;
284+
Constants.RegisterSpace = Param.Constants.RegisterSpace;
285+
Constants.ShaderRegister = Param.Constants.ShaderRegister;
286+
RS.Parameters.addParameter(Header, Constants);
286287
break;
287288
case llvm::to_underlying(dxbc::RootParameterType::SRV):
288289
case llvm::to_underlying(dxbc::RootParameterType::UAV):
289290
case llvm::to_underlying(dxbc::RootParameterType::CBV):
290291
if (RS.Version == 1) {
291-
NewParam.Descriptor_V10.RegisterSpace =
292-
Param.Descriptor.RegisterSpace;
293-
NewParam.Descriptor_V10.ShaderRegister =
294-
Param.Descriptor.ShaderRegister;
292+
dxbc::RST0::v0::RootDescriptor Descriptor;
293+
Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
294+
Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
295+
RS.Parameters.addParameter(Header, Descriptor);
295296
} else {
296-
NewParam.Descriptor_V11.RegisterSpace =
297-
Param.Descriptor.RegisterSpace;
298-
NewParam.Descriptor_V11.ShaderRegister =
299-
Param.Descriptor.ShaderRegister;
300-
NewParam.Descriptor_V11.Flags = Param.Descriptor.getEncodedFlags();
297+
dxbc::RST0::v1::RootDescriptor Descriptor;
298+
Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
299+
Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
300+
Descriptor.Flags = Param.Descriptor.getEncodedFlags();
301+
RS.Parameters.addParameter(Header, Descriptor);
301302
}
302303

303304
break;
305+
default:
306+
// Handling invalid parameter type edge case
307+
RS.Parameters.addHeader(Header, -1);
304308
}
305-
306-
RS.Parameters.push_back(NewParam);
307309
}
308310

309311
RS.write(OS);

0 commit comments

Comments
 (0)