Skip to content

Commit 6d0a795

Browse files
committed
Add DXILVersion for DXIL metadata abstractions
1 parent 2b13ac2 commit 6d0a795

File tree

9 files changed

+240
-48
lines changed

9 files changed

+240
-48
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===- DXIL.h - Abstractions for DXIL constructs ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file This file defines various abstractions for transforming between DXIL's
10+
// and LLVM's representations of shader metadata.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_TRANSFORMS_UTILS_DXIL_H
15+
#define LLVM_TRANSFORMS_UTILS_DXIL_H
16+
17+
#include "llvm/Support/Error.h"
18+
#include "llvm/TargetParser/Triple.h"
19+
20+
namespace llvm {
21+
class Module;
22+
23+
namespace dxil {
24+
25+
class DXILVersion {
26+
unsigned Major = 0;
27+
unsigned Minor = 0;
28+
29+
public:
30+
DXILVersion() = default;
31+
DXILVersion(unsigned Major, unsigned Minor) : Major(Major), Minor(Minor) {}
32+
33+
/// Get the DXILVersion for \c M
34+
static Expected<DXILVersion> get(Module &M);
35+
/// Read the DXILVersion from the DXIL metadata in \c M
36+
static Expected<DXILVersion> readDXIL(Module &M);
37+
38+
/// Returns true if no DXILVersion is set
39+
bool empty() { return Major == 0 && Minor == 0; }
40+
41+
/// Remove any non-DXIL LLVM representations of the DXILVersion from \c M.
42+
void strip(Module &M);
43+
/// Embed the LLVM representation of the DXILVersion into \c M.
44+
void embed(Module &M);
45+
/// Remove any DXIL representation of the DXILVersion from \c M.
46+
void stripDXIL(Module &M);
47+
/// Embed a DXIL representation of the DXILVersion into \c M.
48+
void embedDXIL(Module &M);
49+
50+
void print(raw_ostream &OS) const {
51+
// Format like Triple ArchName.
52+
OS << "dxilv" << Major << "." << Minor;
53+
}
54+
LLVM_DUMP_METHOD void dump() const { print(errs()); }
55+
};
56+
57+
inline raw_ostream &operator<<(raw_ostream &OS, const DXILVersion &V) {
58+
V.print(OS);
59+
return OS;
60+
}
61+
62+
} // namespace dxil
63+
} // namespace llvm
64+
65+
#endif // LLVM_TRANSFORMS_UTILS_DXIL_H
66+

llvm/lib/Target/DirectX/DXILMetadata.cpp

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -81,52 +81,6 @@ void dxil::createShaderModelMD(Module &M) {
8181
Entry->addOperand(MDNode::get(Ctx, Vals));
8282
}
8383

84-
void dxil::createDXILVersionMD(Module &M) {
85-
Triple TT(Triple::normalize(M.getTargetTriple()));
86-
VersionTuple Ver = VersionTuple(1, 0);
87-
switch (TT.getSubArch()) {
88-
case Triple::DXILSubArch_v1_0:
89-
Ver = VersionTuple(1, 0);
90-
break;
91-
case Triple::DXILSubArch_v1_1:
92-
Ver = VersionTuple(1, 1);
93-
break;
94-
case Triple::DXILSubArch_v1_2:
95-
Ver = VersionTuple(1, 2);
96-
break;
97-
case Triple::DXILSubArch_v1_3:
98-
Ver = VersionTuple(1, 3);
99-
break;
100-
case Triple::DXILSubArch_v1_4:
101-
Ver = VersionTuple(1, 4);
102-
break;
103-
case Triple::DXILSubArch_v1_5:
104-
Ver = VersionTuple(1, 5);
105-
break;
106-
case Triple::DXILSubArch_v1_6:
107-
Ver = VersionTuple(1, 6);
108-
break;
109-
case Triple::DXILSubArch_v1_7:
110-
Ver = VersionTuple(1, 7);
111-
break;
112-
case Triple::DXILSubArch_v1_8:
113-
Ver = VersionTuple(1, 8);
114-
break;
115-
case Triple::NoSubArch:
116-
break;
117-
default:
118-
llvm_unreachable("Unsupported subarch for DXIL generation.");
119-
break;
120-
}
121-
LLVMContext &Ctx = M.getContext();
122-
IRBuilder<> B(Ctx);
123-
NamedMDNode *Entry = M.getOrInsertNamedMetadata("dx.version");
124-
Metadata *Vals[2];
125-
Vals[0] = ConstantAsMetadata::get(B.getInt32(Ver.getMajor()));
126-
Vals[1] = ConstantAsMetadata::get(B.getInt32(Ver.getMinor().value_or(0)));
127-
Entry->addOperand(MDNode::get(Ctx, Vals));
128-
}
129-
13084
static uint32_t getShaderStage(Triple::EnvironmentType Env) {
13185
return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
13286
}

llvm/lib/Target/DirectX/DXILMetadata.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class ValidatorVersionMD {
3333
};
3434

3535
void createShaderModelMD(Module &M);
36-
void createDXILVersionMD(Module &M);
3736
void createEntryMD(Module &M, const uint64_t ShaderFlags);
3837

3938
} // namespace dxil

llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/IR/Module.h"
2020
#include "llvm/Pass.h"
2121
#include "llvm/TargetParser/Triple.h"
22+
#include "llvm/Transforms/Utils/DXIL.h"
2223

2324
using namespace llvm;
2425
using namespace llvm::dxil;
@@ -48,7 +49,13 @@ bool DXILTranslateMetadata::runOnModule(Module &M) {
4849
if (ValVerMD.isEmpty())
4950
ValVerMD.update(VersionTuple(1, 0));
5051
dxil::createShaderModelMD(M);
51-
dxil::createDXILVersionMD(M);
52+
Expected<dxil::DXILVersion> DXILVer = dxil::DXILVersion::get(M);
53+
if (auto E = DXILVer.takeError()) {
54+
errs() << "Fail to get DXIL version " << toString(std::move(E)) << "\n";
55+
return false;
56+
}
57+
58+
DXILVer->embedDXIL(M);
5259

5360
const dxil::Resources &Res =
5461
getAnalysis<DXILResourceWrapper>().getDXILResource();

llvm/lib/Transforms/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_llvm_component_library(LLVMTransformUtils
2020
CountVisits.cpp
2121
Debugify.cpp
2222
DemoteRegToStack.cpp
23+
DXIL.cpp
2324
DXILUpgrade.cpp
2425
EntryExitInstrumenter.cpp
2526
EscapeEnumerator.cpp

llvm/lib/Transforms/Utils/DXIL.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
//===- DXIL.cpp - Abstractions for DXIL constructs ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/Utils/DXIL.h"
10+
#include "llvm/ADT/SmallString.h"
11+
#include "llvm/ADT/StringSwitch.h"
12+
#include "llvm/IR/Constants.h"
13+
#include "llvm/IR/IRBuilder.h"
14+
#include "llvm/IR/Metadata.h"
15+
#include "llvm/IR/Module.h"
16+
17+
using namespace llvm;
18+
using namespace llvm::dxil;
19+
20+
static Error errInvalid(const char *Msg) {
21+
return createStringError(std::errc::invalid_argument, Msg);
22+
}
23+
24+
template <typename... Ts>
25+
static Error errInvalid(const char *Fmt, const Ts &... Vals) {
26+
return createStringError(std::errc::invalid_argument, Fmt, Vals...);
27+
}
28+
29+
Expected<DXILVersion> DXILVersion::get(Module &M) {
30+
Triple TT(Triple::normalize(M.getTargetTriple()));
31+
32+
if (!TT.isDXIL())
33+
return errInvalid("Cannot get DXIL version for arch '%s'",
34+
TT.getArchName().str().c_str());
35+
36+
switch (TT.getSubArch()) {
37+
case Triple::NoSubArch:
38+
case Triple::DXILSubArch_v1_0:
39+
return DXILVersion(1, 0);
40+
case Triple::DXILSubArch_v1_1:
41+
return DXILVersion(1, 1);
42+
case Triple::DXILSubArch_v1_2:
43+
return DXILVersion(1, 2);
44+
case Triple::DXILSubArch_v1_3:
45+
return DXILVersion(1, 3);
46+
case Triple::DXILSubArch_v1_4:
47+
return DXILVersion(1, 4);
48+
case Triple::DXILSubArch_v1_5:
49+
return DXILVersion(1, 5);
50+
case Triple::DXILSubArch_v1_6:
51+
return DXILVersion(1, 6);
52+
case Triple::DXILSubArch_v1_7:
53+
return DXILVersion(1, 7);
54+
case Triple::DXILSubArch_v1_8:
55+
return DXILVersion(1, 8);
56+
default:
57+
return errInvalid("Cannot get DXIL version for arch '%s'",
58+
TT.getArchName().str().c_str());
59+
}
60+
}
61+
62+
Expected<DXILVersion> DXILVersion::readDXIL(Module &M) {
63+
NamedMDNode *DXILVersionMD = M.getNamedMetadata("dx.version");
64+
if (!DXILVersionMD)
65+
return DXILVersion();
66+
67+
if (DXILVersionMD->getNumOperands() != 1)
68+
return errInvalid("dx.version must have one operand");
69+
70+
MDNode *N = DXILVersionMD->getOperand(0);
71+
if (N->getNumOperands() != 2)
72+
return errInvalid("dx.version must have 2 components, not %d",
73+
N->getNumOperands());
74+
75+
const auto *MajorOp = mdconst::dyn_extract<ConstantInt>(N->getOperand(0));
76+
const auto *MinorOp = mdconst::dyn_extract<ConstantInt>(N->getOperand(1));
77+
if (!MajorOp)
78+
return errInvalid("dx.version major version must be an integer");
79+
if (!MinorOp)
80+
return errInvalid("dx.version minor version must be an integer");
81+
82+
return DXILVersion(MajorOp->getZExtValue(), MinorOp->getZExtValue());
83+
}
84+
85+
void DXILVersion::strip(Module &M) {
86+
M.setTargetTriple("dxil-ms-dx");
87+
}
88+
89+
void DXILVersion::embed(Module &M) {
90+
Triple TT(Triple::normalize(M.getTargetTriple()));
91+
SmallString<64> Triple;
92+
raw_svector_ostream OS(Triple);
93+
print(OS);
94+
OS << "-" << TT.getVendorName() << "-" << TT.getOSAndEnvironmentName();
95+
M.setTargetTriple(OS.str());
96+
}
97+
98+
void DXILVersion::stripDXIL(Module &M) {
99+
if (NamedMDNode *V = M.getNamedMetadata("dx.version")) {
100+
V->dropAllReferences();
101+
V->eraseFromParent();
102+
}
103+
}
104+
105+
void DXILVersion::embedDXIL(Module &M) {
106+
LLVMContext &Ctx = M.getContext();
107+
IRBuilder<> B(Ctx);
108+
109+
Metadata *Vals[2];
110+
Vals[0] = ConstantAsMetadata::get(B.getInt32(Major));
111+
Vals[1] = ConstantAsMetadata::get(B.getInt32(Minor));
112+
MDNode *MD = MDNode::get(Ctx, Vals);
113+
114+
NamedMDNode *V = M.getOrInsertNamedMetadata("dx.version");
115+
if (V->getNumOperands())
116+
V->setOperand(0, MD);
117+
else
118+
V->addOperand(MD);
119+
}

llvm/lib/Transforms/Utils/DXILUpgrade.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/IR/Metadata.h"
1212
#include "llvm/IR/Module.h"
1313
#include "llvm/Support/Debug.h"
14+
#include "llvm/Transforms/Utils/DXIL.h"
1415

1516
using namespace llvm;
1617

@@ -33,6 +34,20 @@ static bool handleValVerMetadata(Module &M) {
3334
return true;
3435
}
3536

37+
static bool handleDXILVerMetadata(Module &M) {
38+
auto V = dxil::DXILVersion::readDXIL(M);
39+
if (Error E = V.takeError()) {
40+
report_fatal_error(std::move(E), /*gen_crash_diag=*/false);
41+
}
42+
if (V->empty())
43+
return false;
44+
45+
LLVM_DEBUG(dbgs() << "DXIL: DXIL Version " << *V << "\n");
46+
V->embed(M);
47+
V->stripDXIL(M);
48+
return true;
49+
}
50+
3651
PreservedAnalyses DXILUpgradePass::run(Module &M, ModuleAnalysisManager &AM) {
3752
PreservedAnalyses PA;
3853
// We never add, remove, or change functions here.
@@ -41,6 +56,7 @@ PreservedAnalyses DXILUpgradePass::run(Module &M, ModuleAnalysisManager &AM) {
4156

4257
bool Changed = false;
4358
Changed |= handleValVerMetadata(M);
59+
Changed |= handleDXILVerMetadata(M);
4460

4561
if (!Changed)
4662
return PreservedAnalyses::all();
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: opt -S -dxil-metadata-emit %s | FileCheck %s
2+
target triple = "dxil-pc-shadermodel-vertex"
3+
4+
; CHECK: !dx.version = !{![[DXVER:[0-9]+]]}
5+
; CHECK: ![[DXVER]] = !{i32 1, i32 0}
6+
7+
define void @entry() #0 {
8+
entry:
9+
ret void
10+
}
11+
12+
attributes #0 = { noinline nounwind "hlsl.shader"="vertex" }
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; RUN: opt -passes=dxil-upgrade -S < %s | FileCheck %s
2+
3+
; Ensure that both the dxil version metadata and its operand are removed.
4+
; CHECK: !unrelated_md1 = !{!0}
5+
; CHECK-NOT: !dx.version
6+
; CHECK: !unrelated_md2 = !{!1}
7+
;
8+
; CHECK: !0 = !{i32 1234}
9+
; CHECK-NOT: !{i32 1, i32 7}
10+
; CHECK: !1 = !{i32 4321}
11+
12+
!unrelated_md1 = !{!0}
13+
!dx.version = !{!1}
14+
!unrelated_md2 = !{!2}
15+
16+
!0 = !{i32 1234}
17+
!1 = !{i32 1, i32 7}
18+
!2 = !{i32 4321}

0 commit comments

Comments
 (0)