Skip to content

[DXIL] Consume Metadata Analysis information in passes #108034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 23, 2024
Merged
8 changes: 4 additions & 4 deletions llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ class Function;
namespace dxil {

struct EntryProperties {
const Function *Entry;
const Function *Entry{nullptr};
// Specific target shader stage may be specified for entry functions
Triple::EnvironmentType ShaderStage = Triple::UnknownEnvironment;
Triple::EnvironmentType ShaderStage{Triple::UnknownEnvironment};
unsigned NumThreadsX{0}; // X component
unsigned NumThreadsY{0}; // Y component
unsigned NumThreadsZ{0}; // Z component

EntryProperties(const Function &Fn) : Entry(&Fn) {};
EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
};

struct ModuleMetadataInfo {
VersionTuple DXILVersion{};
VersionTuple ShaderModelVersion{};
Triple::EnvironmentType ShaderStage = Triple::UnknownEnvironment;
Triple::EnvironmentType ShaderProfile{Triple::UnknownEnvironment};
VersionTuple ValidatorVersion{};
SmallVector<EntryProperties> EntryPropertyVec{};
void print(raw_ostream &OS) const;
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Analysis/DXILMetadataAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
Triple TT(Triple(M.getTargetTriple()));
MMDAI.DXILVersion = TT.getDXILVersion();
MMDAI.ShaderModelVersion = TT.getOSVersion();
MMDAI.ShaderStage = TT.getEnvironment();
MMDAI.ShaderProfile = TT.getEnvironment();
NamedMDNode *ValidatorVerNode = M.getNamedMetadata("dx.valver");
if (ValidatorVerNode) {
auto *ValVerMD = cast<MDNode>(ValidatorVerNode->getOperand(0));
Expand All @@ -42,7 +42,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
if (!F.hasFnAttribute("hlsl.shader"))
continue;

EntryProperties EFP(F);
EntryProperties EFP(&F);
// Get "hlsl.shader" attribute
Attribute EntryAttr = F.getFnAttribute("hlsl.shader");
assert(EntryAttr.isValid() &&
Expand Down Expand Up @@ -74,8 +74,8 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
void ModuleMetadataInfo::print(raw_ostream &OS) const {
OS << "Shader Model Version : " << ShaderModelVersion.getAsString() << "\n";
OS << "DXIL Version : " << DXILVersion.getAsString() << "\n";
OS << "Target Shader Stage : " << Triple::getEnvironmentTypeName(ShaderStage)
<< "\n";
OS << "Target Shader Stage : "
<< Triple::getEnvironmentTypeName(ShaderProfile) << "\n";
OS << "Validator Version : " << ValidatorVersion.getAsString() << "\n";
for (const auto &EP : EntryPropertyVec) {
OS << " " << EP.Entry->getName() << "\n";
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ add_llvm_target(DirectXCodeGen
DXContainerGlobals.cpp
DXILFinalizeLinkage.cpp
DXILIntrinsicExpansion.cpp
DXILMetadata.cpp
DXILOpBuilder.cpp
DXILOpLowering.cpp
DXILPrepare.cpp
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/Target/DirectX/DXContainerGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
dxil::ModuleMetadataInfo &MMI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
assert(MMI.EntryPropertyVec.size() == 1 ||
MMI.ShaderStage == Triple::Library);
MMI.ShaderProfile == Triple::Library);
PSV.BaseData.ShaderStage =
static_cast<uint8_t>(MMI.ShaderStage - Triple::Pixel);
static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);

addResourcesForPSV(M, PSV);

Expand All @@ -215,7 +215,7 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
// TODO: Lots more stuff to do here!
//
// See issue https://github.com/llvm/llvm-project/issues/96674.
switch (MMI.ShaderStage) {
switch (MMI.ShaderProfile) {
case Triple::Compute:
PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
Expand All @@ -225,10 +225,10 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
break;
}

if (MMI.ShaderStage != Triple::Library)
if (MMI.ShaderProfile != Triple::Library)
PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();

PSV.finalize(MMI.ShaderStage);
PSV.finalize(MMI.ShaderProfile);
PSV.write(OS);
Constant *Constant =
ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
Expand Down
Loading