Skip to content

Commit a4eb0db

Browse files
authored
[HLSL][RootSignature] Add metadata generation for descriptor tables (#139633)
- prereq: Modify `RootSignatureAttr` to hold a reference to the owned declaration - Define and implement `MetadataBuilder` in `HLSLRootSignature` - Integrate and invoke the builder in `CGHLSLRuntime.cpp` to generate the Root Signature for any associated entry functions - Add tests to demonstrate functionality in `RootSignature.hlsl` Resolves #126584 Note: this is essentially just #125131 rebased onto the new approach of constructing a root signature decl, instead of holding the elements in `AdditionalMembers`.
1 parent 2e6433b commit a4eb0db

File tree

6 files changed

+164
-7
lines changed

6 files changed

+164
-7
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4739,7 +4739,8 @@ def Error : InheritableAttr {
47394739
def RootSignature : Attr {
47404740
/// [RootSignature(Signature)]
47414741
let Spellings = [Microsoft<"RootSignature">];
4742-
let Args = [IdentifierArgument<"Signature">];
4742+
let Args = [IdentifierArgument<"SignatureIdent">,
4743+
DeclArgument<HLSLRootSignature, "SignatureDecl", 0, /*fake=*/1>];
47434744
let Subjects = SubjectList<[Function],
47444745
ErrorDiag, "'function'">;
47454746
let LangOpts = [HLSL];

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
6868
DXILValMD->addOperand(Val);
6969
}
7070

71+
void addRootSignature(ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
72+
llvm::Function *Fn, llvm::Module &M) {
73+
auto &Ctx = M.getContext();
74+
75+
llvm::hlsl::rootsig::MetadataBuilder Builder(Ctx, Elements);
76+
MDNode *RootSignature = Builder.BuildRootSignature();
77+
MDNode *FnPairing =
78+
MDNode::get(Ctx, {ValueAsMetadata::get(Fn), RootSignature});
79+
80+
StringRef RootSignatureValKey = "dx.rootsignatures";
81+
auto *RootSignatureValMD = M.getOrInsertNamedMetadata(RootSignatureValKey);
82+
RootSignatureValMD->addOperand(FnPairing);
83+
}
84+
7185
} // namespace
7286

7387
llvm::Type *
@@ -423,6 +437,13 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
423437
// FIXME: Handle codegen for return type semantics.
424438
// See: https://github.com/llvm/llvm-project/issues/57875
425439
B.CreateRetVoid();
440+
441+
// Add and identify root signature to function, if applicable
442+
for (const Attr *Attr : FD->getAttrs()) {
443+
if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Attr))
444+
addRootSignature(RSAttr->getSignatureDecl()->getRootElements(), EntryFn,
445+
M);
446+
}
426447
}
427448

428449
void CGHLSLRuntime::setHLSLFunctionAttributes(const FunctionDecl *FD,

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
959959

960960
IdentifierInfo *Ident = AL.getArgAsIdent(0)->getIdentifierInfo();
961961
if (auto *RS = D->getAttr<RootSignatureAttr>()) {
962-
if (RS->getSignature() != Ident) {
962+
if (RS->getSignatureIdent() != Ident) {
963963
Diag(AL.getLoc(), diag::err_disallowed_duplicate_attribute) << RS;
964964
return;
965965
}
@@ -970,10 +970,11 @@ void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
970970

971971
LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
972972
if (SemaRef.LookupQualifiedName(R, D->getDeclContext()))
973-
if (isa<HLSLRootSignatureDecl>(R.getFoundDecl())) {
973+
if (auto *SignatureDecl =
974+
dyn_cast<HLSLRootSignatureDecl>(R.getFoundDecl())) {
974975
// Perform validation of constructs here
975-
D->addAttr(::new (getASTContext())
976-
RootSignatureAttr(getASTContext(), AL, Ident));
976+
D->addAttr(::new (getASTContext()) RootSignatureAttr(
977+
getASTContext(), AL, Ident, SignatureDecl));
977978
}
978979
}
979980

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -o - %s | FileCheck %s
2+
3+
// CHECK: !dx.rootsignatures = !{![[#FIRST_ENTRY:]], ![[#SECOND_ENTRY:]]}
4+
5+
// CHECK: ![[#FIRST_ENTRY]] = !{ptr @FirstEntry, ![[#EMPTY:]]}
6+
// CHECK: ![[#EMPTY]] = !{}
7+
8+
[shader("compute"), RootSignature("")]
9+
[numthreads(1,1,1)]
10+
void FirstEntry() {}
11+
12+
// CHECK: ![[#SECOND_ENTRY]] = !{ptr @SecondEntry, ![[#SECOND_RS:]]}
13+
// CHECK: ![[#SECOND_RS]] = !{![[#TABLE:]]}
14+
// CHECK: ![[#TABLE]] = !{!"DescriptorTable", i32 0, ![[#CBV:]], ![[#SRV:]]}
15+
// CHECK: ![[#CBV]] = !{!"CBV", i32 1, i32 0, i32 0, i32 -1, i32 4}
16+
// CHECK: ![[#SRV]] = !{!"SRV", i32 4, i32 42, i32 3, i32 32, i32 0}
17+
18+
#define SampleDescriptorTable \
19+
"DescriptorTable( " \
20+
" CBV(b0), " \
21+
" SRV(t42, space = 3, offset = 32, numDescriptors = 4, flags = 0) " \
22+
")"
23+
[shader("compute"), RootSignature(SampleDescriptorTable)]
24+
[numthreads(1,1,1)]
25+
void SecondEntry() {}
26+
27+
// Sanity test to ensure no root is added for this function as there is only
28+
// two entries in !dx.roosignatures
29+
[shader("compute")]
30+
[numthreads(1,1,1)]
31+
void ThirdEntry() {}

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include <variant>
2121

2222
namespace llvm {
23+
class LLVMContext;
24+
class MDNode;
25+
class Metadata;
26+
2327
namespace hlsl {
2428
namespace rootsig {
2529

@@ -84,7 +88,9 @@ struct RootConstants {
8488
// Models the end of a descriptor table and stores its visibility
8589
struct DescriptorTable {
8690
ShaderVisibility Visibility = ShaderVisibility::All;
87-
uint32_t NumClauses = 0; // The number of clauses in the table
91+
// Denotes that the previous NumClauses in the RootElement array
92+
// are the clauses in the table.
93+
uint32_t NumClauses = 0;
8894

8995
void dump(raw_ostream &OS) const;
9096
};
@@ -119,12 +125,47 @@ struct DescriptorTableClause {
119125
void dump(raw_ostream &OS) const;
120126
};
121127

122-
// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
128+
/// Models RootElement : RootFlags | RootConstants | DescriptorTable
129+
/// | DescriptorTableClause
130+
///
131+
/// A Root Signature is modeled in-memory by an array of RootElements. These
132+
/// aim to map closely to their DSL grammar reprsentation defined in the spec.
133+
///
134+
/// Each optional parameter has its default value defined in the struct, and,
135+
/// each mandatory parameter does not have a default initialization.
136+
///
137+
/// For the variants RootFlags, RootConstants and DescriptorTableClause: each
138+
/// data member maps directly to a parameter in the grammar.
139+
///
140+
/// The DescriptorTable is modelled by having its Clauses as the previous
141+
/// RootElements in the array, and it holds a data member for the Visibility
142+
/// parameter.
123143
using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
124144
DescriptorTableClause>;
125145

126146
void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements);
127147

148+
class MetadataBuilder {
149+
public:
150+
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
151+
: Ctx(Ctx), Elements(Elements) {}
152+
153+
/// Iterates through the elements and dispatches onto the correct Build method
154+
///
155+
/// Accumulates the root signature and returns the Metadata node that is just
156+
/// a list of all the elements
157+
MDNode *BuildRootSignature();
158+
159+
private:
160+
/// Define the various builders for the different metadata types
161+
MDNode *BuildDescriptorTable(const DescriptorTable &Table);
162+
MDNode *BuildDescriptorTableClause(const DescriptorTableClause &Clause);
163+
164+
llvm::LLVMContext &Ctx;
165+
ArrayRef<RootElement> Elements;
166+
SmallVector<Metadata *> GeneratedMetadata;
167+
};
168+
128169
} // namespace rootsig
129170
} // namespace hlsl
130171
} // namespace llvm

llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
1414
#include "llvm/ADT/bit.h"
15+
#include "llvm/IR/IRBuilder.h"
16+
#include "llvm/IR/Metadata.h"
17+
#include "llvm/IR/Module.h"
1518

1619
namespace llvm {
1720
namespace hlsl {
@@ -160,6 +163,65 @@ void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
160163
OS << "}";
161164
}
162165

166+
MDNode *MetadataBuilder::BuildRootSignature() {
167+
for (const RootElement &Element : Elements) {
168+
MDNode *ElementMD = nullptr;
169+
if (const auto &Clause = std::get_if<DescriptorTableClause>(&Element))
170+
ElementMD = BuildDescriptorTableClause(*Clause);
171+
if (const auto &Table = std::get_if<DescriptorTable>(&Element))
172+
ElementMD = BuildDescriptorTable(*Table);
173+
174+
// FIXME(#126586): remove once all RootElemnt variants are handled in a
175+
// visit or otherwise
176+
assert(ElementMD != nullptr &&
177+
"Constructed an unhandled root element type.");
178+
179+
GeneratedMetadata.push_back(ElementMD);
180+
}
181+
182+
return MDNode::get(Ctx, GeneratedMetadata);
183+
}
184+
185+
MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
186+
IRBuilder<> Builder(Ctx);
187+
SmallVector<Metadata *> TableOperands;
188+
// Set the mandatory arguments
189+
TableOperands.push_back(MDString::get(Ctx, "DescriptorTable"));
190+
TableOperands.push_back(ConstantAsMetadata::get(
191+
Builder.getInt32(llvm::to_underlying(Table.Visibility))));
192+
193+
// Remaining operands are references to the table's clauses. The in-memory
194+
// representation of the Root Elements created from parsing will ensure that
195+
// the previous N elements are the clauses for this table.
196+
assert(Table.NumClauses <= GeneratedMetadata.size() &&
197+
"Table expected all owned clauses to be generated already");
198+
// So, add a refence to each clause to our operands
199+
TableOperands.append(GeneratedMetadata.end() - Table.NumClauses,
200+
GeneratedMetadata.end());
201+
// Then, remove those clauses from the general list of Root Elements
202+
GeneratedMetadata.pop_back_n(Table.NumClauses);
203+
204+
return MDNode::get(Ctx, TableOperands);
205+
}
206+
207+
MDNode *MetadataBuilder::BuildDescriptorTableClause(
208+
const DescriptorTableClause &Clause) {
209+
IRBuilder<> Builder(Ctx);
210+
std::string Name;
211+
llvm::raw_string_ostream OS(Name);
212+
OS << Clause.Type;
213+
return MDNode::get(
214+
Ctx, {
215+
MDString::get(Ctx, OS.str()),
216+
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
217+
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
218+
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
219+
ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
220+
ConstantAsMetadata::get(
221+
Builder.getInt32(llvm::to_underlying(Clause.Flags))),
222+
});
223+
}
224+
163225
} // namespace rootsig
164226
} // namespace hlsl
165227
} // namespace llvm

0 commit comments

Comments
 (0)