Skip to content

Commit eb2929d

Browse files
authored
[DirectX] use DXILMetadataAnalysis to build PSVRuntimeInfo (llvm#107101)
Replace the hardcoded values for compute shader in DXContainer::addPipelineStateValidationInfo. Still missing wave size. Add preserved for previous passes so the information is not lost. Fix llvm/wg-hlsl#51
1 parent b2048de commit eb2929d

File tree

4 files changed

+67
-7
lines changed

4 files changed

+67
-7
lines changed

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/ADT/SmallVector.h"
1616
#include "llvm/ADT/StringExtras.h"
1717
#include "llvm/ADT/StringRef.h"
18+
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1819
#include "llvm/BinaryFormat/DXContainer.h"
1920
#include "llvm/CodeGen/Passes.h"
2021
#include "llvm/IR/Constants.h"
@@ -57,6 +58,7 @@ class DXContainerGlobals : public llvm::ModulePass {
5758
void getAnalysisUsage(AnalysisUsage &AU) const override {
5859
AU.setPreservesAll();
5960
AU.addRequired<ShaderFlagsAnalysisWrapper>();
61+
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
6062
}
6163
};
6264

@@ -143,23 +145,35 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
143145
SmallString<256> Data;
144146
raw_svector_ostream OS(Data);
145147
PSVRuntimeInfo PSV;
146-
Triple TT(M.getTargetTriple());
147148
PSV.BaseData.MinimumWaveLaneCount = 0;
148149
PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
150+
151+
dxil::ModuleMetadataInfo &MMI =
152+
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
153+
assert(MMI.EntryPropertyVec.size() == 1 ||
154+
MMI.ShaderStage == Triple::Library);
149155
PSV.BaseData.ShaderStage =
150-
static_cast<uint8_t>(TT.getEnvironment() - Triple::Pixel);
156+
static_cast<uint8_t>(MMI.ShaderStage - Triple::Pixel);
151157

152158
// Hardcoded values here to unblock loading the shader into D3D.
153159
//
154160
// TODO: Lots more stuff to do here!
155161
//
156162
// See issue https://github.com/llvm/llvm-project/issues/96674.
157-
PSV.BaseData.NumThreadsX = 1;
158-
PSV.BaseData.NumThreadsY = 1;
159-
PSV.BaseData.NumThreadsZ = 1;
160-
PSV.EntryName = "main";
163+
switch (MMI.ShaderStage) {
164+
case Triple::Compute:
165+
PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
166+
PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
167+
PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
168+
break;
169+
default:
170+
break;
171+
}
172+
173+
if (MMI.ShaderStage != Triple::Library)
174+
PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
161175

162-
PSV.finalize(TT.getEnvironment());
176+
PSV.finalize(MMI.ShaderStage);
163177
PSV.write(OS);
164178
Constant *Constant =
165179
ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
@@ -170,6 +184,7 @@ char DXContainerGlobals::ID = 0;
170184
INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
171185
"DXContainer Global Emitter", false, true)
172186
INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
187+
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
173188
INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
174189
"DXContainer Global Emitter", false, true)
175190

llvm/lib/Target/DirectX/DXILPrepare.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/ADT/STLExtras.h"
2020
#include "llvm/ADT/SmallVector.h"
2121
#include "llvm/ADT/StringSet.h"
22+
#include "llvm/Analysis/DXILMetadataAnalysis.h"
2223
#include "llvm/CodeGen/Passes.h"
2324
#include "llvm/IR/AttributeMask.h"
2425
#include "llvm/IR/IRBuilder.h"
@@ -247,6 +248,7 @@ class DXILPrepareModule : public ModulePass {
247248
void getAnalysisUsage(AnalysisUsage &AU) const override {
248249
AU.addPreserved<ShaderFlagsAnalysisWrapper>();
249250
AU.addPreserved<DXILResourceMDWrapper>();
251+
AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
250252
}
251253
static char ID; // Pass identification.
252254
};

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "DXILShaderFlags.h"
1414
#include "DirectX.h"
1515
#include "llvm/ADT/StringSet.h"
16+
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1617
#include "llvm/Analysis/DXILResource.h"
1718
#include "llvm/IR/Constants.h"
1819
#include "llvm/IR/Metadata.h"
@@ -103,6 +104,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
103104
AU.addRequired<DXILResourceWrapperPass>();
104105
AU.addRequired<DXILResourceMDWrapper>();
105106
AU.addRequired<ShaderFlagsAnalysisWrapper>();
107+
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
106108
}
107109

108110
bool runOnModule(Module &M) override {
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s
2+
; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
3+
target triple = "dxil-unknown-shadermodel6.0-compute"
4+
5+
; CHECK: @dx.psv0 = private constant [80 x i8] c"{{.*}}", section "PSV0", align 4
6+
7+
define void @cs_main() #0 {
8+
entry:
9+
ret void
10+
}
11+
12+
attributes #0 = { "hlsl.numthreads"="8,8,1" "hlsl.shader"="compute" }
13+
14+
!dx.valver = !{!0}
15+
16+
!0 = !{i32 1, i32 7}
17+
18+
; DXC: - Name: PSV0
19+
; DXC-NEXT: Size: 80
20+
; DXC-NEXT: PSVInfo:
21+
; DXC-NEXT: Version: 3
22+
; DXC-NEXT: ShaderStage: 5
23+
; DXC-NEXT: MinimumWaveLaneCount: 0
24+
; DXC-NEXT: MaximumWaveLaneCount: 4294967295
25+
; DXC-NEXT: UsesViewID: 0
26+
; DXC-NEXT: SigInputVectors: 0
27+
; DXC-NEXT: SigOutputVectors: [ 0, 0, 0, 0 ]
28+
; DXC-NEXT: NumThreadsX: 8
29+
; DXC-NEXT: NumThreadsY: 8
30+
; DXC-NEXT: NumThreadsZ: 1
31+
; DXC-NEXT: EntryName: cs_main
32+
; DXC-NEXT: ResourceStride: 24
33+
; DXC-NEXT: Resources: []
34+
; DXC-NEXT: SigInputElements: []
35+
; DXC-NEXT: SigOutputElements: []
36+
; DXC-NEXT: SigPatchOrPrimElements: []
37+
; DXC-NEXT: InputOutputMap:
38+
; DXC-NEXT: - [ ]
39+
; DXC-NEXT: - [ ]
40+
; DXC-NEXT: - [ ]
41+
; DXC-NEXT: - [ ]

0 commit comments

Comments
 (0)