Skip to content

Commit 8d63afb

Browse files
joaosaffranjoaosaffran
andauthored
[NFC] Refactoring MCDXBC to support out of order storage of root parameters (#137284)
This PR refactors mcdxbc data structure for root signatures to support out of order storage of in memory root signature data. closes: #139585 --------- Co-authored-by: joaosaffran <[email protected]>
1 parent f3f2832 commit 8d63afb

File tree

5 files changed

+140
-61
lines changed

5 files changed

+140
-61
lines changed

llvm/include/llvm/BinaryFormat/DXContainer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ struct RootDescriptor : public v1::RootDescriptor {
602602
uint32_t Flags;
603603

604604
RootDescriptor() = default;
605-
RootDescriptor(v1::RootDescriptor &Base)
605+
explicit RootDescriptor(v1::RootDescriptor &Base)
606606
: v1::RootDescriptor(Base), Flags(0u) {}
607607

608608
void swapBytes() {

llvm/include/llvm/MC/DXContainerRootSignature.h

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,69 @@ namespace llvm {
1515
class raw_ostream;
1616
namespace mcdxbc {
1717

18-
struct RootParameter {
18+
struct RootParameterInfo {
1919
dxbc::RootParameterHeader Header;
20-
union {
21-
dxbc::RootConstants Constants;
22-
dxbc::RTS0::v2::RootDescriptor Descriptor;
23-
};
20+
size_t Location;
21+
22+
RootParameterInfo() = default;
23+
24+
RootParameterInfo(dxbc::RootParameterHeader Header, size_t Location)
25+
: Header(Header), Location(Location) {}
26+
};
27+
28+
struct RootParametersContainer {
29+
SmallVector<RootParameterInfo> ParametersInfo;
30+
31+
SmallVector<dxbc::RootConstants> Constants;
32+
SmallVector<dxbc::RTS0::v2::RootDescriptor> Descriptors;
33+
34+
void addInfo(dxbc::RootParameterHeader Header, size_t Location) {
35+
ParametersInfo.push_back(RootParameterInfo(Header, Location));
36+
}
37+
38+
void addParameter(dxbc::RootParameterHeader Header,
39+
dxbc::RootConstants Constant) {
40+
addInfo(Header, Constants.size());
41+
Constants.push_back(Constant);
42+
}
43+
44+
void addInvalidParameter(dxbc::RootParameterHeader Header) {
45+
addInfo(Header, -1);
46+
}
47+
48+
void addParameter(dxbc::RootParameterHeader Header,
49+
dxbc::RTS0::v2::RootDescriptor Descriptor) {
50+
addInfo(Header, Descriptors.size());
51+
Descriptors.push_back(Descriptor);
52+
}
53+
54+
const std::pair<uint32_t, uint32_t>
55+
getTypeAndLocForParameter(uint32_t Location) const {
56+
const RootParameterInfo &Info = ParametersInfo[Location];
57+
return {Info.Header.ParameterType, Info.Location};
58+
}
59+
60+
const dxbc::RootParameterHeader &getHeader(size_t Location) const {
61+
const RootParameterInfo &Info = ParametersInfo[Location];
62+
return Info.Header;
63+
}
64+
65+
const dxbc::RootConstants &getConstant(size_t Index) const {
66+
return Constants[Index];
67+
}
68+
69+
const dxbc::RTS0::v2::RootDescriptor &getRootDescriptor(size_t Index) const {
70+
return Descriptors[Index];
71+
}
72+
73+
size_t size() const { return ParametersInfo.size(); }
74+
75+
SmallVector<RootParameterInfo>::const_iterator begin() const {
76+
return ParametersInfo.begin();
77+
}
78+
SmallVector<RootParameterInfo>::const_iterator end() const {
79+
return ParametersInfo.end();
80+
}
2481
};
2582
struct RootSignatureDesc {
2683

@@ -29,7 +86,7 @@ struct RootSignatureDesc {
2986
uint32_t RootParameterOffset = 0U;
3087
uint32_t StaticSamplersOffset = 0u;
3188
uint32_t NumStaticSamplers = 0u;
32-
SmallVector<mcdxbc::RootParameter> Parameters;
89+
mcdxbc::RootParametersContainer ParametersContainer;
3390

3491
void write(raw_ostream &OS) const;
3592

llvm/lib/MC/DXContainerRootSignature.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ static void rewriteOffsetToCurrentByte(raw_svector_ostream &Stream,
3030

3131
size_t RootSignatureDesc::getSize() const {
3232
size_t Size = sizeof(dxbc::RootSignatureHeader) +
33-
Parameters.size() * sizeof(dxbc::RootParameterHeader);
33+
ParametersContainer.size() * sizeof(dxbc::RootParameterHeader);
3434

35-
for (const mcdxbc::RootParameter &P : Parameters) {
36-
switch (P.Header.ParameterType) {
35+
for (const RootParameterInfo &I : ParametersContainer) {
36+
switch (I.Header.ParameterType) {
3737
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
3838
Size += sizeof(dxbc::RootConstants);
3939
break;
@@ -56,7 +56,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
5656
raw_svector_ostream BOS(Storage);
5757
BOS.reserveExtraSpace(getSize());
5858

59-
const uint32_t NumParameters = Parameters.size();
59+
const uint32_t NumParameters = ParametersContainer.size();
6060

6161
support::endian::write(BOS, Version, llvm::endianness::little);
6262
support::endian::write(BOS, NumParameters, llvm::endianness::little);
@@ -66,7 +66,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
6666
support::endian::write(BOS, Flags, llvm::endianness::little);
6767

6868
SmallVector<uint32_t> ParamsOffsets;
69-
for (const mcdxbc::RootParameter &P : Parameters) {
69+
for (const RootParameterInfo &P : ParametersContainer) {
7070
support::endian::write(BOS, P.Header.ParameterType,
7171
llvm::endianness::little);
7272
support::endian::write(BOS, P.Header.ShaderVisibility,
@@ -78,27 +78,33 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
7878
assert(NumParameters == ParamsOffsets.size());
7979
for (size_t I = 0; I < NumParameters; ++I) {
8080
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
81-
const mcdxbc::RootParameter &P = Parameters[I];
82-
83-
switch (P.Header.ParameterType) {
84-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
85-
support::endian::write(BOS, P.Constants.ShaderRegister,
81+
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
82+
switch (Type) {
83+
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
84+
const dxbc::RootConstants &Constants =
85+
ParametersContainer.getConstant(Loc);
86+
support::endian::write(BOS, Constants.ShaderRegister,
8687
llvm::endianness::little);
87-
support::endian::write(BOS, P.Constants.RegisterSpace,
88+
support::endian::write(BOS, Constants.RegisterSpace,
8889
llvm::endianness::little);
89-
support::endian::write(BOS, P.Constants.Num32BitValues,
90+
support::endian::write(BOS, Constants.Num32BitValues,
9091
llvm::endianness::little);
9192
break;
93+
}
9294
case llvm::to_underlying(dxbc::RootParameterType::CBV):
9395
case llvm::to_underlying(dxbc::RootParameterType::SRV):
94-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
95-
support::endian::write(BOS, P.Descriptor.ShaderRegister,
96+
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
97+
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
98+
ParametersContainer.getRootDescriptor(Loc);
99+
100+
support::endian::write(BOS, Descriptor.ShaderRegister,
96101
llvm::endianness::little);
97-
support::endian::write(BOS, P.Descriptor.RegisterSpace,
102+
support::endian::write(BOS, Descriptor.RegisterSpace,
98103
llvm::endianness::little);
99104
if (Version > 1)
100-
support::endian::write(BOS, P.Descriptor.Flags,
101-
llvm::endianness::little);
105+
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
106+
break;
107+
}
102108
}
103109
}
104110
assert(Storage.size() == getSize());

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -274,27 +274,33 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
274274
RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
275275

276276
for (const auto &Param : P.RootSignature->Parameters) {
277-
mcdxbc::RootParameter NewParam;
278-
NewParam.Header = dxbc::RootParameterHeader{
279-
Param.Type, Param.Visibility, Param.Offset};
277+
dxbc::RootParameterHeader Header{Param.Type, Param.Visibility,
278+
Param.Offset};
280279

281280
switch (Param.Type) {
282281
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
283-
NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues;
284-
NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace;
285-
NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister;
282+
dxbc::RootConstants Constants;
283+
Constants.Num32BitValues = Param.Constants.Num32BitValues;
284+
Constants.RegisterSpace = Param.Constants.RegisterSpace;
285+
Constants.ShaderRegister = Param.Constants.ShaderRegister;
286+
RS.ParametersContainer.addParameter(Header, Constants);
286287
break;
287288
case llvm::to_underlying(dxbc::RootParameterType::SRV):
288289
case llvm::to_underlying(dxbc::RootParameterType::UAV):
289290
case llvm::to_underlying(dxbc::RootParameterType::CBV):
290-
NewParam.Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
291-
NewParam.Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
292-
if (P.RootSignature->Version > 1)
293-
NewParam.Descriptor.Flags = Param.Descriptor.getEncodedFlags();
291+
dxbc::RTS0::v2::RootDescriptor Descriptor;
292+
Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace;
293+
Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister;
294+
if (RS.Version > 1)
295+
Descriptor.Flags = Param.Descriptor.getEncodedFlags();
296+
RS.ParametersContainer.addParameter(Header, Descriptor);
294297
break;
298+
default:
299+
// Handling invalid parameter type edge case. We intentionally let
300+
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
301+
// for that to be used as a testing tool more effectively.
302+
RS.ParametersContainer.addInvalidParameter(Header);
295303
}
296-
297-
RS.Parameters.push_back(NewParam);
298304
}
299305

300306
RS.write(OS);

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -75,31 +75,34 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
7575
if (RootConstantNode->getNumOperands() != 5)
7676
return reportError(Ctx, "Invalid format for RootConstants Element");
7777

78-
mcdxbc::RootParameter NewParameter;
79-
NewParameter.Header.ParameterType =
78+
dxbc::RootParameterHeader Header;
79+
// The parameter offset doesn't matter here - we recalculate it during
80+
// serialization Header.ParameterOffset = 0;
81+
Header.ParameterType =
8082
llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
8183

8284
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
83-
NewParameter.Header.ShaderVisibility = *Val;
85+
Header.ShaderVisibility = *Val;
8486
else
8587
return reportError(Ctx, "Invalid value for ShaderVisibility");
8688

89+
dxbc::RootConstants Constants;
8790
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
88-
NewParameter.Constants.ShaderRegister = *Val;
91+
Constants.ShaderRegister = *Val;
8992
else
9093
return reportError(Ctx, "Invalid value for ShaderRegister");
9194

9295
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
93-
NewParameter.Constants.RegisterSpace = *Val;
96+
Constants.RegisterSpace = *Val;
9497
else
9598
return reportError(Ctx, "Invalid value for RegisterSpace");
9699

97100
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
98-
NewParameter.Constants.Num32BitValues = *Val;
101+
Constants.Num32BitValues = *Val;
99102
else
100103
return reportError(Ctx, "Invalid value for Num32BitValues");
101104

102-
RSD.Parameters.push_back(NewParameter);
105+
RSD.ParametersContainer.addParameter(Header, Constants);
103106

104107
return false;
105108
}
@@ -164,12 +167,12 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
164167
return reportValueError(Ctx, "RootFlags", RSD.Flags);
165168
}
166169

167-
for (const mcdxbc::RootParameter &P : RSD.Parameters) {
168-
if (!dxbc::isValidShaderVisibility(P.Header.ShaderVisibility))
170+
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
171+
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
169172
return reportValueError(Ctx, "ShaderVisibility",
170-
P.Header.ShaderVisibility);
173+
Info.Header.ShaderVisibility);
171174

172-
assert(dxbc::isValidParameterType(P.Header.ParameterType) &&
175+
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
173176
"Invalid value for ParameterType");
174177
}
175178

@@ -287,33 +290,40 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
287290
OS << indent(Space) << "Version: " << RS.Version << "\n";
288291
OS << indent(Space) << "RootParametersOffset: " << RS.RootParameterOffset
289292
<< "\n";
290-
OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n";
293+
OS << indent(Space) << "NumParameters: " << RS.ParametersContainer.size()
294+
<< "\n";
291295
Space++;
292-
for (auto const &P : RS.Parameters) {
293-
OS << indent(Space) << "- Parameter Type: " << P.Header.ParameterType
294-
<< "\n";
296+
for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
297+
const auto &[Type, Loc] =
298+
RS.ParametersContainer.getTypeAndLocForParameter(I);
299+
const dxbc::RootParameterHeader Header =
300+
RS.ParametersContainer.getHeader(I);
301+
302+
OS << indent(Space) << "- Parameter Type: " << Type << "\n";
295303
OS << indent(Space + 2)
296-
<< "Shader Visibility: " << P.Header.ShaderVisibility << "\n";
297-
switch (P.Header.ParameterType) {
298-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
299-
OS << indent(Space + 2)
300-
<< "Register Space: " << P.Constants.RegisterSpace << "\n";
304+
<< "Shader Visibility: " << Header.ShaderVisibility << "\n";
305+
306+
switch (Type) {
307+
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
308+
const dxbc::RootConstants &Constants =
309+
RS.ParametersContainer.getConstant(Loc);
310+
OS << indent(Space + 2) << "Register Space: " << Constants.RegisterSpace
311+
<< "\n";
301312
OS << indent(Space + 2)
302-
<< "Shader Register: " << P.Constants.ShaderRegister << "\n";
313+
<< "Shader Register: " << Constants.ShaderRegister << "\n";
303314
OS << indent(Space + 2)
304-
<< "Num 32 Bit Values: " << P.Constants.Num32BitValues << "\n";
305-
break;
315+
<< "Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
306316
}
317+
}
318+
Space--;
307319
}
308-
Space--;
309320
OS << indent(Space) << "NumStaticSamplers: " << 0 << "\n";
310321
OS << indent(Space) << "StaticSamplersOffset: " << RS.StaticSamplersOffset
311322
<< "\n";
312323

313324
Space--;
314325
// end root signature header
315326
}
316-
317327
return PreservedAnalyses::all();
318328
}
319329

0 commit comments

Comments
 (0)