Skip to content

[DirectX] use DXILMetadataAnalysis to build PSVRuntimeInfo #107101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 5, 2024

Conversation

python3kgae
Copy link
Contributor

Replace the hardcoded values for compute shader in DXContainer::addPipelineStateValidationInfo.
Still missing wave size.

Add preserved for previous passes so the information is not lost.

Fix llvm/wg-hlsl#51

Replace the hardcoded values for compute shader in DXContainer::addPipelineStateValidationInfo.
Still missing wave size.

Add preserved for previous passes so the information is not lost.

Fix llvm/wg-hlsl#51
@llvmbot
Copy link
Member

llvmbot commented Sep 3, 2024

@llvm/pr-subscribers-backend-directx

Author: Xiang Li (python3kgae)

Changes

Replace the hardcoded values for compute shader in DXContainer::addPipelineStateValidationInfo.
Still missing wave size.

Add preserved for previous passes so the information is not lost.

Fix llvm/wg-hlsl#51


Full diff: https://github.com/llvm/llvm-project/pull/107101.diff

4 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+19-4)
  • (modified) llvm/lib/Target/DirectX/DXILPrepare.cpp (+2)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+2)
  • (added) llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll (+41)
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index d47b9c7a25b8fe..1a9e16f102c007 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/Constants.h"
@@ -57,6 +58,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
     AU.addRequired<ShaderFlagsAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
   }
 };
 
@@ -149,15 +151,27 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   PSV.BaseData.ShaderStage =
       static_cast<uint8_t>(TT.getEnvironment() - Triple::Pixel);
 
+  dxil::ModuleMetadataInfo &MMI =
+      getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+  assert(MMI.EntryPropertyVec.size() != 0 ||
+         TT.getEnvironment() == Triple::Library);
   // Hardcoded values here to unblock loading the shader into D3D.
   //
   // TODO: Lots more stuff to do here!
   //
   // See issue https://github.com/llvm/llvm-project/issues/96674.
-  PSV.BaseData.NumThreadsX = 1;
-  PSV.BaseData.NumThreadsY = 1;
-  PSV.BaseData.NumThreadsZ = 1;
-  PSV.EntryName = "main";
+  switch (TT.getEnvironment()) {
+  case Triple::Compute:
+    PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
+    PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
+    PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
+    break;
+  default:
+    break;
+  }
+
+  if (TT.getEnvironment() != Triple::Library)
+    PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
 
   PSV.finalize(TT.getEnvironment());
   PSV.write(OS);
@@ -170,6 +184,7 @@ char DXContainerGlobals::ID = 0;
 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals",
                       "DXContainer Global Emitter", false, true)
 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
+INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals",
                     "DXContainer Global Emitter", false, true)
 
diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 56098864e987fb..f6b7355b936255 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/AttributeMask.h"
 #include "llvm/IR/IRBuilder.h"
@@ -247,6 +248,7 @@ class DXILPrepareModule : public ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
     AU.addPreserved<DXILResourceMDWrapper>();
+    AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
   }
   static char ID; // Pass identification.
 };
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 2c6d20112060df..701dbc1353dab9 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -14,6 +14,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
@@ -103,6 +104,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
     AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceMDWrapper>();
     AU.addRequired<ShaderFlagsAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
   }
 
   bool runOnModule(Module &M) override {
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll b/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll
new file mode 100644
index 00000000000000..595e70092bb081
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll
@@ -0,0 +1,41 @@
+; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: @dx.psv0 = private constant [80 x i8] c"{{.*}}", section "PSV0", align 4
+
+define void @cs_main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="8,8,1" "hlsl.shader"="compute" }
+
+!dx.valver = !{!0}
+
+!0 = !{i32 1, i32 7}
+
+; DXC: - Name:            PSV0
+; DXC-NEXT:   Size:            80
+; DXC-NEXT:    PSVInfo:
+; DXC-NEXT:      Version:         3
+; DXC-NEXT:      ShaderStage:     5
+; DXC-NEXT:      MinimumWaveLaneCount: 0
+; DXC-NEXT:      MaximumWaveLaneCount: 4294967295
+; DXC-NEXT:      UsesViewID:      0
+; DXC-NEXT:      SigInputVectors: 0
+; DXC-NEXT:      SigOutputVectors: [ 0, 0, 0, 0 ]
+; DXC-NEXT:      NumThreadsX:     8
+; DXC-NEXT:      NumThreadsY:     8
+; DXC-NEXT:      NumThreadsZ:     1
+; DXC-NEXT:      EntryName:       cs_main
+; DXC-NEXT:      ResourceStride:  24
+; DXC-NEXT:      Resources:       []
+; DXC-NEXT:      SigInputElements: []
+; DXC-NEXT:      SigOutputElements: []
+; DXC-NEXT:      SigPatchOrPrimElements: []
+; DXC-NEXT:      InputOutputMap:
+; DXC-NEXT:        - [  ]
+; DXC-NEXT:        - [  ]
+; DXC-NEXT:        - [  ]
+; DXC-NEXT:        - [  ]

Copy link

github-actions bot commented Sep 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

dxil::ModuleMetadataInfo &MMI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
assert(MMI.EntryPropertyVec.size() != 0 ||
TT.getEnvironment() == Triple::Library);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the shader stage from MMI viz., MMI.ShaderStage be used instead of once again getting the triple from module and getting the shader stage from the triple throughout this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

Would changes to the rest of the function to use MMI.ShaderStage as follows, also be appropriate?

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index c6998283850f..4b59251cfd90 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -145,7 +145,6 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   SmallString<256> Data;
   raw_svector_ostream OS(Data);
   PSVRuntimeInfo PSV;
-  Triple TT(M.getTargetTriple());
   PSV.BaseData.MinimumWaveLaneCount = 0;
   PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max();
 
@@ -161,7 +160,7 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   // TODO: Lots more stuff to do here!
   //
   // See issue https://github.com/llvm/llvm-project/issues/96674.
-  switch (TT.getEnvironment()) {
+  switch (MMI.ShaderStage) {
   case Triple::Compute:
     PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
     PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
@@ -171,10 +170,10 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
     break;
   }
 
-  if (TT.getEnvironment() != Triple::Library)
+  if (MMI.ShaderStage != Triple::Library)
     PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
 
-  PSV.finalize(TT.getEnvironment());
+  PSV.finalize(MMI.ShaderStage);
   PSV.write(OS);
   Constant *Constant =
       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.
Updated.

@@ -103,6 +104,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
AU.addRequired<DXILResourceWrapperPass>();
AU.addRequired<DXILResourceMDWrapper>();
AU.addRequired<ShaderFlagsAnalysisWrapper>();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
Copy link
Contributor

@bharadwajy bharadwajy Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does DXILMetadataAnalysisWrapperPass need to be added as required for DXILTranslateMetadata pass, given the changes in this PR do not consume results from DXILMetadataAnalysis pass?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The analysis needs to be run before preparePass which removes the attributes.
And we could consume result for translateMetadata pass as well.

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine, but I think as we add more data to the metadata analysis the need to know its details in order to use its results will become untenable. It's probably reasonable to get this in as is and fix that in-tree though.

Comment on lines 153 to 154
assert(MMI.EntryPropertyVec.size() != 0 ||
MMI.ShaderStage == Triple::Library);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What stops us from getting here with zero entry points? I guess we'll have diagnosed that and errored out earlier? Similarly, what if there are multiple entry points? Is that only possible for library shaders?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Should be size == 1.
Multiple entry points or no entry point should be only possible for library shaders.

Comment on lines +163 to +174
switch (MMI.ShaderStage) {
case Triple::Compute:
PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
break;
default:
break;
}

if (MMI.ShaderStage != Triple::Library)
PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems pretty unfortunate that we need to know which fields are valid in the analysis result based on what stage it is - I see you made a similar comment on the metadata analysis PR, but I didn't notice it at the time. We should probably have a discussion on what the API of the metadata analysis should look like (cc @bharadwajy).

@python3kgae python3kgae merged commit eb2929d into llvm:main Sep 5, 2024
8 checks passed
@python3kgae python3kgae deleted the psv_runtime_info_cs branch September 5, 2024 01:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

Entry properties for Compute shader in PSV0 part
4 participants