Skip to content

Commit e56ad22

Browse files
authored
[DirectX] Encapsulate DXILOpLowering's state into a class. NFC
This introduces an anonymous class "OpLowerer" to help with lowering DXIL ops, and moves the DXILOpBuilder there instead of creating a new one for every operation. DXILOpBuilder is also changed to own its IRBuilder, since that makes it simpler to ensure that it isn't misused. Pull Request: #104248
1 parent c8a678b commit e56ad22

File tree

3 files changed

+73
-52
lines changed

3 files changed

+73
-52
lines changed

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
#include "DXILOpBuilder.h"
1313
#include "DXILConstants.h"
14-
#include "llvm/IR/IRBuilder.h"
1514
#include "llvm/IR/Module.h"
1615
#include "llvm/Support/DXILABI.h"
1716
#include "llvm/Support/ErrorHandling.h"
@@ -335,7 +334,7 @@ namespace dxil {
335334
// Triple is well-formed or that the target is supported since these checks
336335
// would have been done at the time the module M is constructed in the earlier
337336
// stages of compilation.
338-
DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {
337+
DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) {
339338
Triple TT(Triple(M.getTargetTriple()));
340339
DXILVersion = TT.getDXILVersion();
341340
ShaderStage = TT.getEnvironment();
@@ -417,10 +416,10 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
417416

418417
// We need to inject the opcode as the first argument.
419418
SmallVector<Value *> OpArgs;
420-
OpArgs.push_back(B.getInt32(llvm::to_underlying(OpCode)));
419+
OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
421420
OpArgs.append(Args.begin(), Args.end());
422421

423-
return B.CreateCall(DXILFn, OpArgs);
422+
return IRB.CreateCall(DXILFn, OpArgs);
424423
}
425424

426425
CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
#include "DXILConstants.h"
1616
#include "llvm/ADT/SmallVector.h"
17-
#include "llvm/TargetParser/Triple.h"
17+
#include "llvm/IR/IRBuilder.h"
1818
#include "llvm/Support/Error.h"
19+
#include "llvm/TargetParser/Triple.h"
1920

2021
namespace llvm {
2122
class Module;
@@ -29,7 +30,9 @@ namespace dxil {
2930

3031
class DXILOpBuilder {
3132
public:
32-
DXILOpBuilder(Module &M, IRBuilderBase &B);
33+
DXILOpBuilder(Module &M);
34+
35+
IRBuilder<> &getIRB() { return IRB; }
3336

3437
/// Create a call instruction for the given DXIL op. The arguments
3538
/// must be valid for an overload of the operation.
@@ -51,7 +54,7 @@ class DXILOpBuilder {
5154
Type *OverloadType = nullptr);
5255

5356
Module &M;
54-
IRBuilderBase &B;
57+
IRBuilder<> IRB;
5558
VersionTuple DXILVersion;
5659
Triple::EnvironmentType ShaderStage;
5760
};

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -73,67 +73,84 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
7373
return NewOperands;
7474
}
7575

76-
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
77-
IRBuilder<> B(M.getContext());
78-
DXILOpBuilder OpBuilder(M, B);
79-
for (User *U : make_early_inc_range(F.users())) {
80-
CallInst *CI = dyn_cast<CallInst>(U);
81-
if (!CI)
82-
continue;
83-
84-
SmallVector<Value *> Args;
85-
B.SetInsertPoint(CI);
86-
if (isVectorArgExpansion(F)) {
87-
SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
88-
Args.append(NewArgs.begin(), NewArgs.end());
89-
} else
90-
Args.append(CI->arg_begin(), CI->arg_end());
91-
92-
Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args,
93-
F.getReturnType());
94-
if (Error E = OpCallOrErr.takeError()) {
95-
std::string Message(toString(std::move(E)));
96-
DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
97-
CI->getDebugLoc());
98-
M.getContext().diagnose(Diag);
99-
continue;
76+
namespace {
77+
class OpLowerer {
78+
Module &M;
79+
DXILOpBuilder OpBuilder;
80+
81+
public:
82+
OpLowerer(Module &M) : M(M), OpBuilder(M) {}
83+
84+
void replaceFunction(Function &F,
85+
llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
86+
for (User *U : make_early_inc_range(F.users())) {
87+
CallInst *CI = dyn_cast<CallInst>(U);
88+
if (!CI)
89+
continue;
90+
91+
if (Error E = ReplaceCall(CI)) {
92+
std::string Message(toString(std::move(E)));
93+
DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
94+
CI->getDebugLoc());
95+
M.getContext().diagnose(Diag);
96+
continue;
97+
}
10098
}
101-
CallInst *OpCall = *OpCallOrErr;
99+
if (F.user_empty())
100+
F.eraseFromParent();
101+
}
102102

103-
CI->replaceAllUsesWith(OpCall);
104-
CI->eraseFromParent();
103+
void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) {
104+
bool IsVectorArgExpansion = isVectorArgExpansion(F);
105+
replaceFunction(F, [&](CallInst *CI) -> Error {
106+
SmallVector<Value *> Args;
107+
OpBuilder.getIRB().SetInsertPoint(CI);
108+
if (IsVectorArgExpansion) {
109+
SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
110+
Args.append(NewArgs.begin(), NewArgs.end());
111+
} else
112+
Args.append(CI->arg_begin(), CI->arg_end());
113+
114+
Expected<CallInst *> OpCall =
115+
OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType());
116+
if (Error E = OpCall.takeError())
117+
return E;
118+
119+
CI->replaceAllUsesWith(*OpCall);
120+
CI->eraseFromParent();
121+
return Error::success();
122+
});
105123
}
106-
if (F.user_empty())
107-
F.eraseFromParent();
108-
}
109124

110-
static bool lowerIntrinsics(Module &M) {
111-
bool Updated = false;
125+
bool lowerIntrinsics() {
126+
bool Updated = false;
112127

113-
for (Function &F : make_early_inc_range(M.functions())) {
114-
if (!F.isDeclaration())
115-
continue;
116-
Intrinsic::ID ID = F.getIntrinsicID();
117-
switch (ID) {
118-
default:
119-
continue;
128+
for (Function &F : make_early_inc_range(M.functions())) {
129+
if (!F.isDeclaration())
130+
continue;
131+
Intrinsic::ID ID = F.getIntrinsicID();
132+
switch (ID) {
133+
default:
134+
continue;
120135
#define DXIL_OP_INTRINSIC(OpCode, Intrin) \
121136
case Intrin: \
122-
lowerIntrinsic(OpCode, F, M); \
137+
replaceFunctionWithOp(F, OpCode); \
123138
break;
124139
#include "DXILOperation.inc"
140+
}
141+
Updated = true;
125142
}
126-
Updated = true;
143+
return Updated;
127144
}
128-
return Updated;
129-
}
145+
};
146+
} // namespace
130147

131148
namespace {
132149
/// A pass that transforms external global definitions into declarations.
133150
class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
134151
public:
135152
PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
136-
if (lowerIntrinsics(M))
153+
if (OpLowerer(M).lowerIntrinsics())
137154
return PreservedAnalyses::none();
138155
return PreservedAnalyses::all();
139156
}
@@ -143,7 +160,9 @@ class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
143160
namespace {
144161
class DXILOpLoweringLegacy : public ModulePass {
145162
public:
146-
bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
163+
bool runOnModule(Module &M) override {
164+
return OpLowerer(M).lowerIntrinsics();
165+
}
147166
StringRef getPassName() const override { return "DXIL Op Lowering"; }
148167
DXILOpLoweringLegacy() : ModulePass(ID) {}
149168

0 commit comments

Comments
 (0)