@@ -174,6 +174,93 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
174
174
return false ;
175
175
}
176
176
177
+ static bool parseDescriptorRange (LLVMContext *Ctx,
178
+ mcdxbc::DescriptorTable &Table,
179
+ MDNode *RangeDescriptorNode) {
180
+
181
+ if (RangeDescriptorNode->getNumOperands () != 6 )
182
+ return reportError (Ctx, " Invalid format for Descriptor Range" );
183
+
184
+ dxbc::RTS0::v2::DescriptorRange Range;
185
+
186
+ std::optional<StringRef> ElementText =
187
+ extractMdStringValue (RangeDescriptorNode, 0 );
188
+
189
+ if (!ElementText.has_value ())
190
+ return reportError (Ctx, " Descriptor Range, first element is not a string." );
191
+
192
+ Range.RangeType =
193
+ StringSwitch<uint32_t >(*ElementText)
194
+ .Case (" CBV" , llvm::to_underlying (dxbc::DescriptorRangeType::CBV))
195
+ .Case (" SRV" , llvm::to_underlying (dxbc::DescriptorRangeType::SRV))
196
+ .Case (" UAV" , llvm::to_underlying (dxbc::DescriptorRangeType::UAV))
197
+ .Case (" Sampler" ,
198
+ llvm::to_underlying (dxbc::DescriptorRangeType::Sampler))
199
+ .Default (~0U );
200
+
201
+ if (Range.RangeType == ~0U )
202
+ return reportError (Ctx, " Invalid Descriptor Range type: " + *ElementText);
203
+
204
+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 1 ))
205
+ Range.NumDescriptors = *Val;
206
+ else
207
+ return reportError (Ctx, " Invalid value for Number of Descriptor in Range" );
208
+
209
+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 2 ))
210
+ Range.BaseShaderRegister = *Val;
211
+ else
212
+ return reportError (Ctx, " Invalid value for BaseShaderRegister" );
213
+
214
+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 3 ))
215
+ Range.RegisterSpace = *Val;
216
+ else
217
+ return reportError (Ctx, " Invalid value for RegisterSpace" );
218
+
219
+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 4 ))
220
+ Range.OffsetInDescriptorsFromTableStart = *Val;
221
+ else
222
+ return reportError (Ctx,
223
+ " Invalid value for OffsetInDescriptorsFromTableStart" );
224
+
225
+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 5 ))
226
+ Range.Flags = *Val;
227
+ else
228
+ return reportError (Ctx, " Invalid value for Descriptor Range Flags" );
229
+
230
+ Table.Ranges .push_back (Range);
231
+ return false ;
232
+ }
233
+
234
+ static bool parseDescriptorTable (LLVMContext *Ctx,
235
+ mcdxbc::RootSignatureDesc &RSD,
236
+ MDNode *DescriptorTableNode) {
237
+ const unsigned int NumOperands = DescriptorTableNode->getNumOperands ();
238
+ if (NumOperands < 2 )
239
+ return reportError (Ctx, " Invalid format for Descriptor Table" );
240
+
241
+ dxbc::RTS0::v1::RootParameterHeader Header;
242
+ if (std::optional<uint32_t > Val = extractMdIntValue (DescriptorTableNode, 1 ))
243
+ Header.ShaderVisibility = *Val;
244
+ else
245
+ return reportError (Ctx, " Invalid value for ShaderVisibility" );
246
+
247
+ mcdxbc::DescriptorTable Table;
248
+ Header.ParameterType =
249
+ llvm::to_underlying (dxbc::RootParameterType::DescriptorTable);
250
+
251
+ for (unsigned int I = 2 ; I < NumOperands; I++) {
252
+ MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand (I));
253
+ if (Element == nullptr )
254
+ return reportError (Ctx, " Missing Root Element Metadata Node." );
255
+
256
+ if (parseDescriptorRange (Ctx, Table, Element))
257
+ return true ;
258
+ }
259
+
260
+ RSD.ParametersContainer .addParameter (Header, Table);
261
+ return false ;
262
+ }
263
+
177
264
static bool parseRootSignatureElement (LLVMContext *Ctx,
178
265
mcdxbc::RootSignatureDesc &RSD,
179
266
MDNode *Element) {
@@ -188,6 +275,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
188
275
.Case (" RootCBV" , RootSignatureElementKind::CBV)
189
276
.Case (" RootSRV" , RootSignatureElementKind::SRV)
190
277
.Case (" RootUAV" , RootSignatureElementKind::UAV)
278
+ .Case (" DescriptorTable" , RootSignatureElementKind::DescriptorTable)
191
279
.Default (RootSignatureElementKind::Error);
192
280
193
281
switch (ElementKind) {
@@ -200,6 +288,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
200
288
case RootSignatureElementKind::SRV:
201
289
case RootSignatureElementKind::UAV:
202
290
return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
291
+ case RootSignatureElementKind::DescriptorTable:
292
+ return parseDescriptorTable (Ctx, RSD, Element);
203
293
case RootSignatureElementKind::Error:
204
294
return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
205
295
}
@@ -241,6 +331,81 @@ static bool verifyRegisterSpace(uint32_t RegisterSpace) {
241
331
242
332
static bool verifyDescriptorFlag (uint32_t Flags) { return (Flags & ~0xE ) == 0 ; }
243
333
334
+ static bool verifyRangeType (uint32_t Type) {
335
+ switch (Type) {
336
+ case llvm::to_underlying (dxbc::DescriptorRangeType::CBV):
337
+ case llvm::to_underlying (dxbc::DescriptorRangeType::SRV):
338
+ case llvm::to_underlying (dxbc::DescriptorRangeType::UAV):
339
+ case llvm::to_underlying (dxbc::DescriptorRangeType::Sampler):
340
+ return true ;
341
+ };
342
+
343
+ return false ;
344
+ }
345
+
346
+ static bool verifyDescriptorRangeFlag (uint32_t Version, uint32_t Type,
347
+ uint32_t FlagsVal) {
348
+ using FlagT = dxbc::DescriptorRangeFlag;
349
+ FlagT Flags = FlagT (FlagsVal);
350
+
351
+ const bool IsSampler =
352
+ (Type == llvm::to_underlying (dxbc::DescriptorRangeType::Sampler));
353
+
354
+ if (Version == 1 ) {
355
+ // Since the metadata is unversioned, we expect to explicitly see the values
356
+ // that map to the version 1 behaviour here.
357
+ if (IsSampler)
358
+ return Flags == FlagT::DESCRIPTORS_VOLATILE;
359
+ return Flags == (FlagT::DATA_VOLATILE | FlagT::DESCRIPTORS_VOLATILE);
360
+ }
361
+
362
+ // The data-specific flags are mutually exclusive.
363
+ FlagT DataFlags = FlagT::DATA_VOLATILE | FlagT::DATA_STATIC |
364
+ FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
365
+
366
+ if (popcount (llvm::to_underlying (Flags & DataFlags)) > 1 )
367
+ return false ;
368
+
369
+ // The descriptor-specific flags are mutually exclusive.
370
+ FlagT DescriptorFlags =
371
+ FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS |
372
+ FlagT::DESCRIPTORS_VOLATILE;
373
+ if (popcount (llvm::to_underlying (Flags & DescriptorFlags)) > 1 )
374
+ return false ;
375
+
376
+ // For volatile descriptors, DATA_STATIC is never valid.
377
+ if ((Flags & FlagT::DESCRIPTORS_VOLATILE) == FlagT::DESCRIPTORS_VOLATILE) {
378
+ FlagT Mask = FlagT::DESCRIPTORS_VOLATILE;
379
+ if (!IsSampler) {
380
+ Mask |= FlagT::DATA_VOLATILE;
381
+ Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
382
+ }
383
+ return (Flags & ~Mask) == FlagT::NONE;
384
+ }
385
+
386
+ // For "STATIC_KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
387
+ // the other data-specific flags may all be set.
388
+ if ((Flags & FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) ==
389
+ FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) {
390
+ FlagT Mask = FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS;
391
+ if (!IsSampler) {
392
+ Mask |= FlagT::DATA_VOLATILE;
393
+ Mask |= FlagT::DATA_STATIC;
394
+ Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
395
+ }
396
+ return (Flags & ~Mask) == FlagT::NONE;
397
+ }
398
+
399
+ // When no descriptor flag is set, any data flag is allowed.
400
+ FlagT Mask = FlagT::NONE;
401
+ if (!IsSampler) {
402
+ Mask |= FlagT::DATA_VOLATILE;
403
+ Mask |= FlagT::DATA_STATIC;
404
+ Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
405
+ }
406
+ return (Flags & ~Mask) == FlagT::NONE;
407
+ }
408
+
244
409
static bool validate (LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
245
410
246
411
if (!verifyVersion (RSD.Version )) {
@@ -275,7 +440,23 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
275
440
276
441
if (RSD.Version > 1 ) {
277
442
if (!verifyDescriptorFlag (Descriptor.Flags ))
278
- return reportValueError (Ctx, " DescriptorFlag" , Descriptor.Flags );
443
+ return reportValueError (Ctx, " DescriptorRangeFlag" , Descriptor.Flags );
444
+ }
445
+ break ;
446
+ }
447
+ case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
448
+ const mcdxbc::DescriptorTable &Table =
449
+ RSD.ParametersContainer .getDescriptorTable (Info.Location );
450
+ for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
451
+ if (!verifyRangeType (Range.RangeType ))
452
+ return reportValueError (Ctx, " RangeType" , Range.RangeType );
453
+
454
+ if (!verifyRegisterSpace (Range.RegisterSpace ))
455
+ return reportValueError (Ctx, " RegisterSpace" , Range.RegisterSpace );
456
+
457
+ if (!verifyDescriptorRangeFlag (RSD.Version , Range.RangeType ,
458
+ Range.Flags ))
459
+ return reportValueError (Ctx, " DescriptorFlag" , Range.Flags );
279
460
}
280
461
break ;
281
462
}
@@ -388,67 +569,67 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
388
569
389
570
OS << " Root Signature Definitions"
390
571
<< " \n " ;
391
- uint8_t Space = 0 ;
392
572
for (const Function &F : M) {
393
573
auto It = RSDMap.find (&F);
394
574
if (It == RSDMap.end ())
395
575
continue ;
396
576
const auto &RS = It->second ;
397
577
OS << " Definition for '" << F.getName () << " ':\n " ;
398
-
399
578
// start root signature header
400
- Space++;
401
- OS << indent (Space) << " Flags: " << format_hex (RS.Flags , 8 ) << " \n " ;
402
- OS << indent (Space) << " Version: " << RS.Version << " \n " ;
403
- OS << indent (Space) << " RootParametersOffset: " << RS.RootParameterOffset
404
- << " \n " ;
405
- OS << indent (Space) << " NumParameters: " << RS.ParametersContainer .size ()
406
- << " \n " ;
407
- Space++;
579
+ OS << " Flags: " << format_hex (RS.Flags , 8 ) << " \n "
580
+ << " Version: " << RS.Version << " \n "
581
+ << " RootParametersOffset: " << RS.RootParameterOffset << " \n "
582
+ << " NumParameters: " << RS.ParametersContainer .size () << " \n " ;
408
583
for (size_t I = 0 ; I < RS.ParametersContainer .size (); I++) {
409
584
const auto &[Type, Loc] =
410
585
RS.ParametersContainer .getTypeAndLocForParameter (I);
411
586
const dxbc::RTS0::v1::RootParameterHeader Header =
412
587
RS.ParametersContainer .getHeader (I);
413
588
414
- OS << indent (Space) << " - Parameter Type: " << Type << " \n " ;
415
- OS << indent (Space + 2 )
416
- << " Shader Visibility: " << Header.ShaderVisibility << " \n " ;
589
+ OS << " - Parameter Type: " << Type << " \n "
590
+ << " Shader Visibility: " << Header.ShaderVisibility << " \n " ;
417
591
418
592
switch (Type) {
419
593
case llvm::to_underlying (dxbc::RootParameterType::Constants32Bit): {
420
594
const dxbc::RTS0::v1::RootConstants &Constants =
421
595
RS.ParametersContainer .getConstant (Loc);
422
- OS << indent (Space + 2 ) << " Register Space: " << Constants.RegisterSpace
423
- << " \n " ;
424
- OS << indent (Space + 2 )
425
- << " Shader Register: " << Constants.ShaderRegister << " \n " ;
426
- OS << indent (Space + 2 )
427
- << " Num 32 Bit Values: " << Constants.Num32BitValues << " \n " ;
596
+ OS << " Register Space: " << Constants.RegisterSpace << " \n "
597
+ << " Shader Register: " << Constants.ShaderRegister << " \n "
598
+ << " Num 32 Bit Values: " << Constants.Num32BitValues << " \n " ;
428
599
break ;
429
600
}
430
601
case llvm::to_underlying (dxbc::RootParameterType::CBV):
431
602
case llvm::to_underlying (dxbc::RootParameterType::UAV):
432
603
case llvm::to_underlying (dxbc::RootParameterType::SRV): {
433
604
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
434
605
RS.ParametersContainer .getRootDescriptor (Loc);
435
- OS << indent (Space + 2 )
436
- << " Register Space: " << Descriptor.RegisterSpace << " \n " ;
437
- OS << indent (Space + 2 )
438
- << " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
606
+ OS << " Register Space: " << Descriptor.RegisterSpace << " \n "
607
+ << " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
439
608
if (RS.Version > 1 )
440
- OS << indent (Space + 2 ) << " Flags: " << Descriptor.Flags << " \n " ;
609
+ OS << " Flags: " << Descriptor.Flags << " \n " ;
610
+ break ;
611
+ }
612
+ case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
613
+ const mcdxbc::DescriptorTable &Table =
614
+ RS.ParametersContainer .getDescriptorTable (Loc);
615
+ OS << " NumRanges: " << Table.Ranges .size () << " \n " ;
616
+
617
+ for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
618
+ OS << " - Range Type: " << Range.RangeType << " \n "
619
+ << " Register Space: " << Range.RegisterSpace << " \n "
620
+ << " Base Shader Register: " << Range.BaseShaderRegister << " \n "
621
+ << " Num Descriptors: " << Range.NumDescriptors << " \n "
622
+ << " Offset In Descriptors From Table Start: "
623
+ << Range.OffsetInDescriptorsFromTableStart << " \n " ;
624
+ if (RS.Version > 1 )
625
+ OS << " Flags: " << Range.Flags << " \n " ;
626
+ }
441
627
break ;
442
628
}
443
629
}
444
- Space--;
445
630
}
446
- OS << indent (Space) << " NumStaticSamplers: " << 0 << " \n " ;
447
- OS << indent (Space) << " StaticSamplersOffset: " << RS.StaticSamplersOffset
448
- << " \n " ;
449
-
450
- Space--;
451
- // end root signature header
631
+ OS << " NumStaticSamplers: " << 0 << " \n " ;
632
+ OS << " StaticSamplersOffset: " << RS.StaticSamplersOffset << " \n " ;
452
633
}
453
634
return PreservedAnalyses::all ();
454
635
}
0 commit comments