@@ -73,67 +73,84 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
73
73
return NewOperands;
74
74
}
75
75
76
- static void lowerIntrinsic (dxil::OpCode DXILOp, Function &F, Module &M) {
77
- IRBuilder<> B (M.getContext ());
78
- DXILOpBuilder OpBuilder (M, B);
79
- for (User *U : make_early_inc_range (F.users ())) {
80
- CallInst *CI = dyn_cast<CallInst>(U);
81
- if (!CI)
82
- continue ;
83
-
84
- SmallVector<Value *> Args;
85
- B.SetInsertPoint (CI);
86
- if (isVectorArgExpansion (F)) {
87
- SmallVector<Value *> NewArgs = argVectorFlatten (CI, B);
88
- Args.append (NewArgs.begin (), NewArgs.end ());
89
- } else
90
- Args.append (CI->arg_begin (), CI->arg_end ());
91
-
92
- Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp (DXILOp, Args,
93
- F.getReturnType ());
94
- if (Error E = OpCallOrErr.takeError ()) {
95
- std::string Message (toString (std::move (E)));
96
- DiagnosticInfoUnsupported Diag (*CI->getFunction (), Message,
97
- CI->getDebugLoc ());
98
- M.getContext ().diagnose (Diag);
99
- continue ;
76
+ namespace {
77
+ class OpLowerer {
78
+ Module &M;
79
+ DXILOpBuilder OpBuilder;
80
+
81
+ public:
82
+ OpLowerer (Module &M) : M(M), OpBuilder(M) {}
83
+
84
+ void replaceFunction (Function &F,
85
+ llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
86
+ for (User *U : make_early_inc_range (F.users ())) {
87
+ CallInst *CI = dyn_cast<CallInst>(U);
88
+ if (!CI)
89
+ continue ;
90
+
91
+ if (Error E = ReplaceCall (CI)) {
92
+ std::string Message (toString (std::move (E)));
93
+ DiagnosticInfoUnsupported Diag (*CI->getFunction (), Message,
94
+ CI->getDebugLoc ());
95
+ M.getContext ().diagnose (Diag);
96
+ continue ;
97
+ }
100
98
}
101
- CallInst *OpCall = *OpCallOrErr;
99
+ if (F.user_empty ())
100
+ F.eraseFromParent ();
101
+ }
102
102
103
- CI->replaceAllUsesWith (OpCall);
104
- CI->eraseFromParent ();
103
+ void replaceFunctionWithOp (Function &F, dxil::OpCode DXILOp) {
104
+ bool IsVectorArgExpansion = isVectorArgExpansion (F);
105
+ replaceFunction (F, [&](CallInst *CI) -> Error {
106
+ SmallVector<Value *> Args;
107
+ OpBuilder.getIRB ().SetInsertPoint (CI);
108
+ if (IsVectorArgExpansion) {
109
+ SmallVector<Value *> NewArgs = argVectorFlatten (CI, OpBuilder.getIRB ());
110
+ Args.append (NewArgs.begin (), NewArgs.end ());
111
+ } else
112
+ Args.append (CI->arg_begin (), CI->arg_end ());
113
+
114
+ Expected<CallInst *> OpCall =
115
+ OpBuilder.tryCreateOp (DXILOp, Args, F.getReturnType ());
116
+ if (Error E = OpCall.takeError ())
117
+ return E;
118
+
119
+ CI->replaceAllUsesWith (*OpCall);
120
+ CI->eraseFromParent ();
121
+ return Error::success ();
122
+ });
105
123
}
106
- if (F.user_empty ())
107
- F.eraseFromParent ();
108
- }
109
124
110
- static bool lowerIntrinsics (Module &M ) {
111
- bool Updated = false ;
125
+ bool lowerIntrinsics () {
126
+ bool Updated = false ;
112
127
113
- for (Function &F : make_early_inc_range (M.functions ())) {
114
- if (!F.isDeclaration ())
115
- continue ;
116
- Intrinsic::ID ID = F.getIntrinsicID ();
117
- switch (ID) {
118
- default :
119
- continue ;
128
+ for (Function &F : make_early_inc_range (M.functions ())) {
129
+ if (!F.isDeclaration ())
130
+ continue ;
131
+ Intrinsic::ID ID = F.getIntrinsicID ();
132
+ switch (ID) {
133
+ default :
134
+ continue ;
120
135
#define DXIL_OP_INTRINSIC (OpCode, Intrin ) \
121
136
case Intrin: \
122
- lowerIntrinsic (OpCode, F, M); \
137
+ replaceFunctionWithOp ( F, OpCode); \
123
138
break ;
124
139
#include " DXILOperation.inc"
140
+ }
141
+ Updated = true ;
125
142
}
126
- Updated = true ;
143
+ return Updated ;
127
144
}
128
- return Updated ;
129
- }
145
+ } ;
146
+ } // namespace
130
147
131
148
namespace {
132
149
// / A pass that transforms external global definitions into declarations.
133
150
class DXILOpLowering : public PassInfoMixin <DXILOpLowering> {
134
151
public:
135
152
PreservedAnalyses run (Module &M, ModuleAnalysisManager &) {
136
- if (lowerIntrinsics (M ))
153
+ if (OpLowerer (M). lowerIntrinsics ( ))
137
154
return PreservedAnalyses::none ();
138
155
return PreservedAnalyses::all ();
139
156
}
@@ -143,7 +160,9 @@ class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
143
160
namespace {
144
161
class DXILOpLoweringLegacy : public ModulePass {
145
162
public:
146
- bool runOnModule (Module &M) override { return lowerIntrinsics (M); }
163
+ bool runOnModule (Module &M) override {
164
+ return OpLowerer (M).lowerIntrinsics ();
165
+ }
147
166
StringRef getPassName () const override { return " DXIL Op Lowering" ; }
148
167
DXILOpLoweringLegacy () : ModulePass(ID) {}
149
168
0 commit comments